import os
import io
# external libraries
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
# custom imports
from model.src.datasets.dresscode import DressCodeDataset
from model.src.datasets.vitonhd import VitonHDDataset
from model.src.mgd_pipelines.mgd_pipe import MGDPipe
from model.src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from model.src.utils.arg_parser import eval_parse_args
from model.src.utils.image_from_pipe import generate_images_from_mgd_pipe
from model.src.utils.set_seeds import set_seed
from PIL import Image
from huggingface_hub import HfApi, HfFolder
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__, log_level="INFO")
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"
hf_token = os.getenv("HF_TOKEN")
api = HfApi()
if hf_token:
    HfFolder.save_token(hf_token)
def main(json_from_req: dict):
    """Run MGD inference; ``json_from_req`` carries the request payload (e.g. the caption) forwarded to the dataset."""
args = eval_parse_args()
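    # Accelerate handles device placement and optional fp16 mixed precision.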
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
)
device = accelerator.device
    # If a seed was provided, set it now for reproducible generation.
if args.seed is not None:
set_seed(args.seed)
# Load scheduler, tokenizer and models.
val_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
val_scheduler.set_timesteps(50, device=device)
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
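    # The denoising UNet is the MGD model (an SD 1.5 inpainting UNet with extra
    # conditioning channels); see load_mgd_model below.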
unet = load_mgd_model(dataset=args.dataset, pretrained=True)
#unet = torch.hub.load(dataset=args.dataset, repo_or_dir='aimagelab/multimodal-garment-designer', source='github',
#model='mgd', pretrained=True)
# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# Enable memory efficient attention if requested
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
if args.category:
category = [args.category]
else:
category = ['dresses', 'upper_body', 'lower_body']
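    # Build the test dataset; json_from_req forwards the request payload
    # (e.g. the caption) to the dataset.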
if args.dataset == "dresscode":
test_dataset = DressCodeDataset(
dataroot_path=args.dataset_path,
phase='test',
order=args.test_order,
radius=5,
sketch_threshold_range=(20, 20),
tokenizer=tokenizer,
category=category,
size=(512, 384),
json_from_req=json_from_req
)
elif args.dataset == "vitonhd":
test_dataset = VitonHDDataset(
dataroot_path=args.dataset_path,
phase='test',
order=args.test_order,
sketch_threshold_range=(20, 20),
radius=5,
tokenizer=tokenizer,
size=(512, 384),
json_from_req=json_from_req
)
else:
raise NotImplementedError
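    # Evaluation dataloader over the test split (no shuffling, batched inference).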
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
shuffle=False,
batch_size=args.batch_size,
num_workers=args.num_workers_test,
)
    # For mixed precision inference we cast the text_encoder and vae weights to half-precision;
    # since these models are only used for inference, keeping the weights in full precision is not required.
weight_dtype = torch.float32
if args.mixed_precision == 'fp16':
weight_dtype = torch.float16
    # Move text_encoder and vae to the device and cast them to weight_dtype
text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
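    # The UNet stays in eval mode; it is cast to the VAE dtype when the pipeline is built below.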
unet.eval()
    # Select standard classifier-free guidance or the disentangled variant, according to args.disentagle.
with torch.inference_mode():
if args.disentagle:
val_pipe = MGDPipeDisentangled(
text_encoder=text_encoder,
vae=vae,
unet=unet.to(vae.dtype),
tokenizer=tokenizer,
scheduler=val_scheduler,
).to(device)
else:
val_pipe = MGDPipe(
text_encoder=text_encoder,
vae=vae,
unet=unet.to(vae.dtype),
tokenizer=tokenizer,
scheduler=val_scheduler,
).to(device)
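        # Attention slicing lowers peak memory usage at a small speed cost.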
val_pipe.enable_attention_slicing()
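        # Let Accelerate wrap the dataloader for device placement (and sharding in distributed runs).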
test_dataloader = accelerator.prepare(test_dataloader)
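        # Run the denoising loop over the test set with the requested guidance scales
        # and conditioning rates, and collect the generated image(s).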
final_image = generate_images_from_mgd_pipe(
test_order=args.test_order,
pipe=val_pipe,
test_dataloader=test_dataloader,
save_name=args.save_name,
dataset=args.dataset,
output_dir=args.output_dir,
guidance_scale=args.guidance_scale,
guidance_scale_pose=args.guidance_scale_pose,
guidance_scale_sketch=args.guidance_scale_sketch,
sketch_cond_rate=args.sketch_cond_rate,
start_cond_rate=args.start_cond_rate,
no_pose=False,
disentagle=False,
seed=args.seed,
)
    return final_image  # return the generated image to the caller
def load_mgd_model(dataset: str, pretrained: bool = True) -> UNet2DConditionModel:
"""
MGD model
pretrained (bool): load pretrained weights into the model
"""
config = UNet2DConditionModel.load_config("benjamin-paine/stable-diffusion-v1-5-inpainting", subfolder="unet")
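    # 28 input channels: 4 noise latents + 4 masked-image latents + 1 inpainting mask
    # + 18 pose-map channels + 1 sketch channel (the MGD conditioning inputs).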
config['in_channels'] = 28
unet = UNet2DConditionModel.from_config(config)
if pretrained:
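        # Per-dataset checkpoints ("dresscode" / "vitonhd") released with the
        # multimodal-garment-designer repository.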
checkpoint = f"https://github.com/aimagelab/multimodal-garment-designer/releases/download/weights/{dataset}.pth"
unet.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True))
return unet
if __name__ == "__main__":
    # main() expects the request payload; pass an empty dict for standalone runs.
    main(json_from_req={})