# fastAPI_CatVTON / test.py
# (Hugging Face Space page residue, preserved as comments: "nrtoya's picture", commit 0152 / bda551d)
import argparse
import os

# Point CUDA tooling at the system install before any CUDA-dependent import runs.
os.environ['CUDA_HOME'] = '/usr/local/cuda'
os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
from datetime import datetime
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download
from PIL import Image

# Globally disable TorchScript compilation by replacing torch.jit.script with a
# no-op; presumably a workaround for a dependency that fails under scripting
# (e.g. DensePose) -- TODO confirm.
torch.jit.script = lambda f: f
from model.cloth_masker import AutoMasker, vis_mask
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
# Additional package
import cv2
# NOTE(review): the CLI argument parsing below is deliberately disabled (dead
# string literal); the script uses the RESUME_PATH environment variable instead.
# If re-enabled, note that parse_args() references args.local_rank although no
# --local_rank argument is declared -- that would raise AttributeError.
'''
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",
        # default="runwayml/stable-diffusion-inpainting",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The Path to the checkpoint of trained tryon model."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank
    return args
args = parse_args()
'''
# Hugging Face Hub repo id holding the CatVTON checkpoints; snapshot is cached locally.
RESUME_PATH = os.getenv("RESUME_PATH", "zhengchong/CatVTON")
repo_path = snapshot_download(repo_id=RESUME_PATH)

# AutoMasker: builds the try-on agnostic mask from DensePose + SCHP checkpoints.
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cuda',
)

person_image = Image.open("./resource/demo/example/person/men/model_7.png").convert("RGB")
mask = automasker(
    person_image,
    'upper'
)['mask']  # the returned mask is a PIL image (see cloth_masker.py); the result dict also exposes ['densepose']
### The mask-modification code below can be added at the marked spot in app.py.
def remove_bottom_part(mask: np.ndarray, y_threshold: int) -> Image.Image:
    """
    Erase everything at or below a given row of the mask.

    :param mask: input mask (numpy array; axis 0 is y)
    :param y_threshold: first row to clear; rows >= y_threshold are set to 0
    :return: the modified mask as a PIL image (NOT a numpy array --
             the original docstring was wrong about this)
    """
    # Work on a copy: the original version zeroed the caller's array in place,
    # silently corrupting any other use of the same buffer.
    result = mask.copy()
    result[y_threshold:, :] = 0
    return Image.fromarray(result)
# closing ์—ฐ์‚ฐ / fitting_mode๊ฐ€ standard ๋‚˜ loose ์ผ๋•Œ๋งŒ ์‚ฌ์šฉํ•˜๊ธฐ
def morph_close(mask):
mask_np = np.array(mask)
kernel = np.ones((30, 30), np.uint8) # ์ปค์งˆ์ˆ˜๋ก ์ž˜ ์—ฐ๊ฒฐ๋จ
closed_mask = cv2.morphologyEx(mask_np, cv2.MORPH_CLOSE, kernel)
return Image.fromarray(closed_mask)
# opening ์—ฐ์‚ฐ / fitting_mode๊ฐ€ standard ๋‚˜ loose ์ผ๋•Œ๋งŒ ์‚ฌ์šฉํ•˜๊ธฐ
def morph_open(mask):
mask_np = np.array(mask)
kernel = np.ones((30, 30), np.uint8) # ์ปค์งˆ์ˆ˜๋ก ์ž˜ ์‚ฌ๋ผ์ง
#closed_mask = cv2.morphologyEx(mask_np, cv2.MORPH_CLOSE, kernel)
opened_mask = cv2.morphologyEx(mask_np, cv2.MORPH_OPEN, kernel) #opened_mask๋Š” numpy ์—ฐ์‚ฐ ๊ฒฐ๊ณผ ์ด๋ฏ€๋กœ PIL ์ด๋ฏธ์ง€๋กœ ๋ณ€ํ™˜ ํ•„์š”
return Image.fromarray(opened_mask)
def morph_open2(mask, kernel_size=10):
    """
    Morphological opening with a gentler (10x10) default kernel than morph_open.

    :param mask: PIL image or array-like mask
    :param kernel_size: side of the square structuring element (default 10
                        preserves the original hard-coded behavior)
    :return: opened mask as a PIL image
    """
    mask_np = np.array(mask)
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    opened_mask = cv2.morphologyEx(mask_np, cv2.MORPH_OPEN, kernel)
    return Image.fromarray(opened_mask)
## opened_mask = morph_open(mask)
## opened_mask.save('./opened_mask.png') #opened_mask๋Š” PIL ์ด๋ฏธ์ง€ ํ˜•ํƒœ๋กœ ๋ฐ˜ํ™˜๋˜์—ˆ์œผ๋ฏ€๋กœ (Image.fromarray()์‚ฌ์šฉํ•ด์„œ) .save๋ฅผ ๋ฐ”๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.
#opened_mask2 = morph_open2(mask)
#kernel = np.ones((50, 50), np.uint8)
#opened_mask2 = cv2.dilate(np.array(opened_mask2), kernel, iterations=1)
#opened_mask2 = Image.fromarray(opened_mask2)
#opened_mask2 = mask_processor.blur(opened_mask2, blur_factor=9)
#opened_mask2.save('./opened_mask2.png')
# mask = mask_processor.blur(mask, blur_factor=9)
## mask.save("./test_mask.png") # ๋งˆ์Šคํฌ๋ฅผ PNG ํŒŒ์ผ๋กœ ์ €์žฅ
## masked_person = vis_mask(person_image, mask) # app.py์—์„œ๋„ blur ์ฒ˜๋ฆฌ ํ•œ ๋‹ค์Œ์— vis_mask ๋ฉ”์„œ๋“œ ํ˜ธ์ถœํ•จ.
## masked_person.save("./test_masked_person.png") # ๋งˆ์Šคํฌ์™€ target img๊ฐ€ ํ•ฉ์ณ์ง„ ์‚ฌ์ง„์„ PNG ํŒŒ์ผ๋กœ ์ €์žฅ
# mask์˜ y์ถ• ์Œ์˜ ๋ฐฉํ–ฅ ์ด๋™
def extend_mask_downward(mask_image: np.ndarray, pixels: int) -> np.ndarray:
"""
y์ถ• ์Œ์˜ ๋ฐฉํ–ฅ์œผ๋กœ (์•„๋ž˜๋กœ) ๋งˆ์Šคํฌ ์ด๋ฏธ์ง€๋ฅผ ํ™•์žฅํ•˜๋Š” ํ•จ์ˆ˜.
:param mask_image: ๋งˆ์Šคํฌ ์ด๋ฏธ์ง€ (numpy ๋ฐฐ์—ด)
:param pixels: ํ™•์žฅํ•  ํ”ฝ์…€ ์ˆ˜
:return: ํ™•์žฅ๋œ ๋งˆ์Šคํฌ ์ด๋ฏธ์ง€ (numpy ๋ฐฐ์—ด)
"""
# ์ด์ง„ํ™”๋œ ๋งˆ์Šคํฌ๋ฅผ ๋งŒ๋“ฆ
mask = cv2.threshold(mask_image, 127, 255, cv2.THRESH_BINARY)[1]
# ํ™•์žฅ์„ ์œ„ํ•œ ์ปค๋„. y์ถ•์œผ๋กœ๋งŒ ํ™•์žฅํ•˜๊ธฐ ์œ„ํ•ด ์„ธ๋กœ ๊ธธ์ด๋ฅผ ํฌ๊ฒŒ ์„ค์ •ํ•จ
kernel = np.zeros((pixels, 1), np.uint8) # y์ถ•์œผ๋กœ๋งŒ ๊ธธ์–ด์ง„ ์ปค๋„
# y์ถ• ์Œ์˜ ๋ฐฉํ–ฅ์œผ๋กœ๋งŒ ํ™•์žฅ (cv2.dilate ์‚ฌ์šฉ)
extended_mask = cv2.dilate(mask, kernel, iterations=1)
return Image.fromarray(extended_mask)
def image_equal(img1, img2):
    """Return True when the two images contain identical pixel data."""
    arr_a = np.asarray(img1)
    arr_b = np.asarray(img2)
    return np.array_equal(arr_a, arr_b)
# automasker์— ๋Œ€ํ•œ fitting ์ •๋„๋ฅผ default ์ฒ˜๋ฆฌํ•˜๊ธฐ ์œ„ํ•œ ์ฝ”๋“œ ์ถ”๊ฐ€.
def extend_mask_downward2(mask_image: np.ndarray, pixels: int) -> Image.Image:
# ์ด์ง„ํ™”๋œ ๋งˆ์Šคํฌ๋ฅผ ๋งŒ๋“ฆ
mask = cv2.threshold(mask_image, 127, 255, cv2.THRESH_BINARY)[1]
height, width = mask.shape[:2]
# ๋งˆ์Šคํฌ๊ฐ€ ์‹œ์ž‘ํ•˜๋Š” y ์ขŒํ‘œ์™€ ๋๋‚˜๋Š” y ์ขŒํ‘œ ์ฐพ๊ธฐ
non_zero_rows = np.where(mask.any(axis=1))[0]
if len(non_zero_rows) == 0:
# ๋งˆ์Šคํฌ๊ฐ€ ๋น„์–ด์žˆ๋Š” ๊ฒฝ์šฐ, ์›๋ณธ ์ด๋ฏธ์ง€๋ฅผ ๋ฐ˜ํ™˜
return Image.fromarray(mask.copy())
mask_start_row = non_zero_rows[0]
mask_end_row = non_zero_rows[-1]
# ๋Š˜์–ด๋‚œ ๋งˆ์Šคํฌ์˜ ๋ y ์ขŒํ‘œ ๊ณ„์‚ฐ (์ด๋ฏธ์ง€ ๋†’์ด๋ฅผ ์ดˆ๊ณผํ•˜์ง€ ์•Š๋„๋ก)
new_mask_end_row = min(mask_end_row + pixels, height - 1)
# ์›๋ณธ ๋งˆ์Šคํฌ ์˜์—ญ๊ณผ ๋Š˜์–ด๋‚œ ๋งˆ์Šคํฌ ์˜์—ญ์˜ ๊ธธ์ด ๊ณ„์‚ฐ
original_mask_length = mask_end_row - mask_start_row + 1
stretched_mask_length = new_mask_end_row - mask_start_row + 1
# y ์ขŒํ‘œ ๋งคํ•‘ ๋ฐฐ์—ด ์ƒ์„ฑ
map_y = np.arange(height, dtype=np.float32)
# ๋งˆ์Šคํฌ ์˜์—ญ์— ๋Œ€ํ•œ y ์ขŒํ‘œ ์žฌ๋งคํ•‘ (์„ ํ˜•์ ์œผ๋กœ ๋Š˜๋ฆฌ๊ธฐ)
if stretched_mask_length > 1:
map_y[mask_start_row:new_mask_end_row + 1] = np.linspace(
mask_start_row,
mask_end_row,
stretched_mask_length
)
else:
map_y[mask_start_row:new_mask_end_row + 1] = mask_start_row
# ๋งˆ์Šคํฌ ์•„๋ž˜ ์˜์—ญ์€ ๋งˆ์ง€๋ง‰ ๋งˆ์Šคํฌ ํ–‰์œผ๋กœ ๋งคํ•‘
if new_mask_end_row + 1 < height:
map_y[new_mask_end_row + 1:] = mask_end_row
# map_y์˜ ํฌ๊ธฐ๋ฅผ (height, width)๋กœ ํ™•์žฅ
map_y = np.repeat(map_y[:, np.newaxis], width, axis=1)
# x ์ขŒํ‘œ๋Š” ๊ทธ๋Œ€๋กœ ์œ ์ง€
map_x = np.tile(np.arange(width, dtype=np.float32), (height, 1))
# ์ด๋ฏธ์ง€ ๋ฆฌ๋งคํ•‘
extended_mask = cv2.remap(
mask,
map_x,
map_y,
interpolation=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0
)
return Image.fromarray(extended_mask)
# ๋งˆ์Šคํฌ๋ฅผ y์ถ• ์Œ์˜ ๋ฐฉํ–ฅ์œผ๋กœ 50ํ”ฝ์…€ ํ™•์žฅ
## extended_mask = extend_mask_downward(np.array(mask), pixels=100)
# ํ™•์žฅ๋œ ๋งˆ์Šคํฌ ์ €์žฅ
## extended_mask.save('extended_mask_image.png')
# ์ตœ์ข… ๋งˆ์Šคํฌ ์ €์žฅ
# fitting ์ •๋„์— ๋”ฐ๋ผ, extended_mask ํ•จ์ˆ˜ ํ˜ธ์ถœ ๋ณ€์ˆ˜์ธ pixels๋ฅผ ์กฐ์ ˆํ•˜๋ฉด ๋œ๋‹ค.
# ์ •ํ™•๋„๋ฅผ ์œ„ํ•ด ๊ทธ๋ƒฅ dilation ํ•˜์ง€ ์•Š๊ณ , y์ขŒํ‘œ๊ฐ€ ์•ฝ๊ฐ„ ๋‹ค๋ฅธ ๋‘ ๋งˆ์Šคํฌ๋ฅผ ํ•ฉ์ณค๋‹ค.
## final_mask = Image.fromarray(np.array(opened_mask) | np.array(extended_mask))
## final_mask = morph_close(morph_open(final_mask)) #๋ถˆํ•„์š”ํ•œ ๋™๋–จ์–ด์ง„ ๋ถ€๋ถ„ ์‚ญ์ œ -> ์—ฐ๊ฒฐ๋˜์ง€ ์•Š์€ ๋ถ€๋ถ„ ์—ฐ๊ฒฐ
## final_mask.save('final_mask_image.png')
## masked_person2 = vis_mask(person_image, final_mask) # app.py์—์„œ๋„ blur ์ฒ˜๋ฆฌ ํ•œ ๋‹ค์Œ์— vis_mask ๋ฉ”์„œ๋“œ ํ˜ธ์ถœํ•จ.
## masked_person2.save("./test_masked_person2.png") # ๋งˆ์Šคํฌ์™€ target img๊ฐ€ ํ•ฉ์ณ์ง„ ์‚ฌ์ง„์„ PNG ํŒŒ์ผ๋กœ ์ €์žฅ
#person_image = Image.open("path_to_image").convert("RGB")
#standard_image = Image.open("./resource/demo/example/person/men/m_lvl3.png").convert("RGB")
"""
compare_image_mlvl3 = Image.open("./resource/demo/example/person/men/m_lvl3.png").convert("RGB")
compare_image_mlvl3 = resize_and_crop(compare_image_mlvl3, (args.width, args.height))
person_image2 = Image.open("./resource/demo/example/person/men/m_lvl0.png").convert("RGB") # ์ด๊ฑธ ์–ด๋Š bmi ๋ ˆ๋ฒจ์„ ๊ธฐ์ค€์œผ๋กœ ์“ธ์ง€๋Š” ๋ญ.. ์‹คํ—˜ํ•ด๋ณด๋ฉด์„œ ์ œ์ผ ์ข‹์€ ๊ฑฐ ์ •ํ•˜๋ฉด ๋จ.
person_image2 = resize_and_crop(person_image2, (args.width, args.height))
mask = automasker(
person_image2,
"upper"
)['mask']
mask.save("./first_mask.png")
# ์ดํ›„ ์ฒ˜๋ฆฌ
sam_mask_lower = Image.open("./resource/demo/example/person/sam/m_lvl3_lower_sam.png").convert("L")
sam_mask_lower = resize_and_crop(sam_mask_lower, (args.width, args.height))
sam_mask_upper = Image.open("./resource/demo/example/person/sam/m_lvl3_upper_sam.png").convert("L")
sam_mask_upper = resize_and_crop(sam_mask_upper, (args.width, args.height))
mask_np = np.array(mask)
sam_mask_upper_np = np.array(sam_mask_upper)
sam_mask_lower_np = np.array(sam_mask_lower)
kernel = np.ones((10, 10), np.uint8)
sam_mask_upper_np = cv2.dilate(sam_mask_upper_np, kernel, iterations=1)
sam_mask_lower_np = cv2.dilate(sam_mask_lower_np, kernel, iterations=1)
result_np = np.where(sam_mask_lower_np== 255, 0, mask_np)
result_np = np.where(sam_mask_upper_np== 255, 255, result_np)
mask = Image.fromarray(result_np)
mask.save("./last_mask2.png")
"""