import os
import sys
sys.path.append('./')
import numpy as np
import argparse
import torch
import torchvision
import pytorch_lightning
from torch import autocast
from torchvision import transforms
from pytorch_lightning import seed_everything
from einops import rearrange
from functools import partial
from omegaconf import OmegaConf
from PIL import Image
from typing import List
import matplotlib.pyplot as plt
import gradio as gr
import spaces
import apply_net
from torchvision.transforms.functional import to_pil_image
# from tools.mask_vitonhd import get_img_agnostic
from utils_mask import get_mask_location
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from ldm.util import instantiate_from_config, get_obj_from_str
from ldm.models.diffusion.ddim import DDIMSampler
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
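# Undo the [-1, 1] normalization, mapping tensors back to [0, 1] for display.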
def un_norm(x):
return (x+1.0)/2.0
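# Undo CLIP's per-channel image normalization (OpenAI CLIP mean/std constants).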
def un_norm_clip(x):
x[0,:,:] = x[0,:,:] * 0.26862954 + 0.48145466
x[1,:,:] = x[1,:,:] * 0.26130258 + 0.4578275
x[2,:,:] = x[2,:,:] * 0.27577711 + 0.40821073
return x
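# Minimal LightningDataModule that builds a test DataLoader from an
# OmegaConf-specified dataset (not used by the Gradio demo below, where the
# data-module instantiation is commented out).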
class DataModuleFromConfig(pytorch_lightning.LightningDataModule):
def __init__(self,
batch_size,
test=None,
wrap=False,
shuffle=False,
shuffle_test_loader=False,
use_worker_init_fn=False):
super().__init__()
self.batch_size = batch_size
self.num_workers = batch_size * 2
self.use_worker_init_fn = use_worker_init_fn
self.wrap = wrap
self.datasets = instantiate_from_config(test)
        self.dataloader = torch.utils.data.DataLoader(self.datasets,
                                                      batch_size=self.batch_size,
                                                      num_workers=self.num_workers,
                                                      shuffle=shuffle,
                                                      worker_init_fn=None)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Script for demo model")
parser.add_argument("-b", "--base", type=str, default=r"configs/test_vitonhd.yaml")
parser.add_argument("-c", "--ckpt", type=str, default=r"ckpt/hitonhd.ckpt")
parser.add_argument("-s", "--seed", type=str, default=42)
parser.add_argument("-d", "--ddim", type=str, default=64)
opt = parser.parse_args()
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.base}")
# data = instantiate_from_config(config.data)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = instantiate_from_config(config.model)
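    # Load the CAT-DM checkpoint; strict=False tolerates key mismatches
    # between the instantiated model and the stored state_dict.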
model.load_state_dict(torch.load(opt.ckpt, map_location="cpu")["state_dict"], strict=False)
    model = model.to(device)
    model.eval()
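    # DDIM sampler drives the iterative denoising at inference time
    # (eta=0 below makes it deterministic).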
sampler = DDIMSampler(model)
precision_scope = autocast
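    # On Hugging Face Spaces, @spaces.GPU requests a GPU for the duration of each call.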
@spaces.GPU
def start_tryon(human_img,garm_img):
#load human image
human_img = human_img['background'].convert("RGB").resize((768,1024))
#mask
        tensor_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
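        # Human parsing + OpenPose keypoints locate the garment region for the agnostic mask.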
parsing_model = Parsing(0)
        openpose_model = OpenPose(0)
        openpose_model.preprocessor.body_estimation.model.to(device)
        keypoints = openpose_model(human_img.resize((384,512)))
model_parse, _ = parsing_model(human_img.resize((384,512)))
mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
mask = mask.resize((768, 1024))
        mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
mask_gray = to_pil_image((mask_gray+1.0)/2.0)
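        # mask_gray: the person image with the try-on region grayed out, returned for visualization.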
        # DensePose: render a dense body-surface segmentation used as the pose hint.
human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
args = apply_net.create_argument_parser().parse_args(('show',
'./configs/configs_densepose/densepose_rcnn_R_50_FPN_s1x.yaml',
'./ckpt/densepose/model_final_162be9.pkl',
'dp_segm', '-v',
'--opts',
'MODEL.DEVICE',
                                                    str(device)))
# verbosity = getattr(args, "verbosity", None)
pose_img = args.func(args,human_img_arg)
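        # apply_net renders a BGR array; flip the channel order to RGB for PIL.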
pose_img = pose_img[:,:,::-1]
pose_img = Image.fromarray(pose_img).resize((768,1024))
#preprocessing image
human_img = human_img.convert("RGB").resize((512, 512))
human_img = torchvision.transforms.ToTensor()(human_img)
garm_img = garm_img.convert("RGB").resize((224, 224))
garm_img = torchvision.transforms.ToTensor()(garm_img)
mask = mask.convert("L").resize((512,512))
mask = torchvision.transforms.ToTensor()(mask)
mask = 1-mask
pose_img = pose_img.convert("RGB").resize((512, 512))
pose_img = torchvision.transforms.ToTensor()(pose_img)
        # Normalize: person and pose maps to [-1, 1]; garment with CLIP image
        # statistics to match the garment (reference) encoder.
human_img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(human_img)
garm_img = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711))(garm_img)
pose_img = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(pose_img)
        # Build the masked person image (inpaint) and the 6-channel hint:
        # garment RGB stacked with the DensePose pose map.
inpaint = human_img * mask
hint = torchvision.transforms.Resize((512, 512))(garm_img)
hint = torch.cat((hint, pose_img), dim=0)
# {"human_img": human_img, # [3, 512, 512]
# "inpaint_image": inpaint, # [3, 512, 512]
# "inpaint_mask": mask, # [1, 512, 512]
# "garm_img": garm_img, # [3, 224, 224]
# "hint": hint, # [6, 512, 512]
# }
with torch.no_grad():
with precision_scope("cuda"):
                # Move inputs to the device in float16 to match the autocast context.
inpaint = inpaint.unsqueeze(0).to(torch.float16).to(device)
reference = garm_img.unsqueeze(0).to(torch.float16).to(device)
mask = mask.unsqueeze(0).to(torch.float16).to(device)
hint = hint.unsqueeze(0).to(torch.float16).to(device)
truth = human_img.unsqueeze(0).to(torch.float16).to(device)
                # Encode the masked person into VAE latent space and resize the mask to the latent resolution.
encoder_posterior_inpaint = model.first_stage_model.encode(inpaint)
z_inpaint = model.scale_factor * (encoder_posterior_inpaint.sample()).detach()
mask_resize = torchvision.transforms.Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(mask)
test_model_kwargs = {}
test_model_kwargs['inpaint_image'] = z_inpaint
test_model_kwargs['inpaint_mask'] = mask_resize
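                # model.image_size / model.channels describe the latent grid, not pixel dimensions.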
shape = (model.channels, model.image_size, model.image_size)
#predict
samples, _ = sampler.sample(S=opt.ddim,
batch_size=1,
shape=shape,
pose=hint,
conditioning=reference,
verbose=False,
eta=0,
test_model_kwargs=test_model_kwargs)
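                # Rescale latents by 1/scale_factor, then decode the first 4 channels with the VAE.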
samples = 1. / model.scale_factor * samples
x_samples = model.first_stage_model.decode(samples[:,:4,:,:])
x_samples_ddim = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                x_checked_image_torch = x_samples_ddim.cpu()
                # Compose the final result: keep the generated pixels inside the
                # try-on region and paste the original person back elsewhere (img_C).
                mask = mask.cpu()
                truth = torch.clamp((truth + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                x_checked_image_torch_C = x_checked_image_torch * (1 - mask) + truth * mask
x_checked_image_torch = torch.nn.functional.interpolate(x_checked_image_torch.float(), size=[512, 384])
x_checked_image_torch_C = torch.nn.functional.interpolate(x_checked_image_torch_C.float(), size=[512, 384])
                # Convert from torch.Tensor to PIL Image
to_pil = transforms.ToPILImage()
img = to_pil(x_checked_image_torch[0].cpu())
img_C = to_pil(x_checked_image_torch_C[0].cpu())
return img, img_C, mask_gray
example_path = os.path.join(os.path.dirname(__file__), 'example')
garm_list = os.listdir(os.path.join(example_path,"cloth"))
garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
human_list = os.listdir(os.path.join(example_path,"human"))
human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
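    # Wrap each example person image in the dict layout gr.ImageEditor expects.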
human_ex_list = []
for ex_human in human_list_path:
ex_dict= {}
ex_dict['background'] = ex_human
ex_dict['layers'] = None
ex_dict['composite'] = None
human_ex_list.append(ex_dict)
    # Default human examples for the gallery
image_blocks = gr.Blocks().queue()
with image_blocks as demo:
gr.Markdown("## CAT-DM 👕👔👚")
gr.Markdown("Virtual Try-on with your image and garment image")
with gr.Row():
with gr.Column():
imgs = gr.ImageEditor(sources='upload', type="pil", label='Human Picture or use Examples below', interactive=True)
example = gr.Examples(
inputs=imgs,
examples_per_page=10,
examples=human_ex_list
)
with gr.Column():
garm_img = gr.Image(label="Garment", sources='upload', type="pil")
example = gr.Examples(
inputs=garm_img,
examples_per_page=8,
examples=garm_list_path
)
with gr.Column():
image_out = gr.Image(label="Output", elem_id="output-img",show_download_button=False)
try_button = gr.Button(value="Try-on")
with gr.Column():
            image_out_c = gr.Image(label="Output (original background pasted back)", elem_id="output-img-c", show_download_button=False)
with gr.Column():
masked_img = gr.Image(label="Masked image output", elem_id="masked_img", show_download_button=False)
try_button.click(fn=start_tryon, inputs=[imgs,garm_img], outputs=[image_out,image_out_c,masked_img], api_name='tryon')
image_blocks.launch()