# CatVTON2/test.py
import argparse
import os
# Make the CUDA toolkit discoverable (e.g., for building detectron2/DensePose extensions).
os.environ['CUDA_HOME'] = '/usr/local/cuda'
os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
from datetime import datetime
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download
from PIL import Image
# Monkeypatch: make torch.jit.script a no-op, a common workaround for TorchScript
# compilation failures in dependencies (e.g., DensePose).
torch.jit.script = lambda f: f
from model.cloth_masker2 import AutoMasker, vis_mask
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
# Added package
import cv2
def parse_args():
    parser = argparse.ArgumentParser(description="Inference/test script for the CatVTON try-on model.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",
        # default="runwayml/stable-diffusion-inpainting",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier"
            " from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help="The path to the checkpoint of the trained try-on model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help="The width to which all input images will be resized.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="The height to which all input images will be resized.",
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up inference. For more information,"
            " see https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system"
            " or the flag passed with the `accelerate launch` command. Use this argument to override the accelerate"
            " config."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed runs; overridden by the LOCAL_RANK environment variable when set.",
    )
    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank
    return args
args = parse_args()
repo_path = snapshot_download(repo_id=args.resume_path)
# AutoMasker
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
automasker = AutoMasker(
densepose_ckpt=os.path.join(repo_path, "DensePose"),
schp_ckpt=os.path.join(repo_path, "SCHP"),
device='cuda',
)
person_image = Image.open("./resource/demo/example/person/men/m_lvl0.png").convert("RGB")
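# The --width/--height args and the resize_and_crop import suggest the person image
# is meant to be resized before masking; this call mirrors how the upstream CatVTON
# app.py prepares inputs (an assumption here, not part of the original snippet).
person_image = resize_and_crop(person_image, (args.width, args.height))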
mask = automasker(
    person_image,
    'short dress'
)['mask']  # The returned mask is a PIL image (see cloth_masker.py); the same dict also exposes the DensePose output under ['densepose'].
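# Per the note above, the same call also returns the DensePose visualization; a
# minimal sketch for inspecting it (assumes it is a PIL image; re-runs the masker):
densepose = automasker(person_image, 'short dress')['densepose']
densepose.save("./test_densepose.png")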
### Add the mask-modification code used here to the section marked in app.py!
# mask = <mask after the modification operations>
# mask = mask_processor.blur(mask, blur_factor=9)
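# A hedged sketch of one possible mask modification using the cv2 import above:
# dilate the binary mask slightly so the try-on region fully covers the garment
# boundary, then apply the blur from the commented line. The kernel size is an
# illustrative choice, not taken from the original code.
mask_np = np.array(mask.convert("L"))
kernel = np.ones((15, 15), np.uint8)
mask_np = cv2.dilate(mask_np, kernel, iterations=1)
mask = Image.fromarray(mask_np)
mask = mask_processor.blur(mask, blur_factor=9)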
masked_person = vis_mask(person_image, mask)  # app.py also calls vis_mask after applying the blur.
mask.save("./test_mask.png")  # Save the mask as a PNG file
masked_person.save("./test_masked_person.png")  # Save the mask overlaid on the person image as a PNG file
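# CatVTONPipeline is imported above but never used in this snippet. A minimal sketch
# of how the try-on step could follow, assuming the constructor and call signature
# used in the upstream CatVTON app.py; treat the kwargs, the cloth image path, and
# the sampling parameters below as unverified assumptions.
pipeline = CatVTONPipeline(
    base_ckpt=args.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda',
)
cloth_image = Image.open("./resource/demo/example/condition/upper/1.jpg").convert("RGB")  # hypothetical path
cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
result_image = pipeline(
    image=person_image,
    condition_image=cloth_image,
    mask=mask,
    num_inference_steps=50,
    guidance_scale=2.5,
)[0]
os.makedirs(args.output_dir, exist_ok=True)
result_image.save(os.path.join(args.output_dir, "result.png"))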