|
from diffusers import StableDiffusionInpaintPipeline |
|
import os |
|
|
|
from tqdm import tqdm |
|
from PIL import Image |
|
import numpy as np |
|
import cv2 |
|
import warnings |
|
|
|
# Silence FutureWarning and DeprecationWarning before the heavier imports
# below run, so the CLI output stays readable.
for _category in (FutureWarning, DeprecationWarning):
    warnings.filterwarnings("ignore", category=_category)
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torchvision.transforms as transforms |
|
|
|
from data.base_dataset import Normalize_image |
|
from utils.saving_utils import load_checkpoint_mgpu |
|
from networks import U2NET |
|
import argparse |
|
from enum import Enum |
|
from rembg import remove |
|
|
|
class Parts:
    """Integer codes used to pick a garment region out of the segmentation map.

    The attribute names are the upper-cased values accepted by the
    ``--part`` command-line option.
    """

    # Kept as plain ints so they compare directly against label arrays.
    UPPER, LOWER = 1, 2
|
|
|
def parse_arguments(argv=None):
    """Parse the command-line options of the Stable Fashion script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            exactly as before; passing a list lets callers and tests drive
            the parser without touching process state.

    Returns:
        argparse.Namespace with the attributes ``image``, ``part``,
        ``resolution``, ``prompt``, ``num_steps``, ``guidance_scale``,
        ``rembg`` and ``output``.

    Unrecognised arguments are ignored rather than rejected, because
    ``parse_known_args`` is used.
    """
    parser = argparse.ArgumentParser(
        description="Stable Fashion API, allows you to picture yourself in any cloth your imagination can think of!"
    )
    parser.add_argument('--image', type=str, required=True, help='path to image')
    parser.add_argument('--part', choices=['upper', 'lower'], default='upper', type=str)
    parser.add_argument('--resolution', choices=[256, 512, 1024, 2048], default=256, type=int)
    parser.add_argument('--prompt', type=str, default="A pink cloth")
    parser.add_argument('--num_steps', type=int, default=5)
    parser.add_argument('--guidance_scale', type=float, default=7.5)
    parser.add_argument('--rembg', action='store_true')
    parser.add_argument('--output', default='output.jpg', type=str)

    # The second element holds the unknown leftovers, which are dropped.
    args, _ = parser.parse_known_args(argv)
    return args
|
|
|
|
|
def load_u2net():
    """Build the U2NET segmentation network and load its checkpoint.

    The weights are read from
    ``trained_checkpoint/cloth_segm_u2net_latest.pth`` (relative to the
    current working directory). The network is moved to CUDA when
    available, otherwise CPU, and switched to evaluation mode.

    Returns:
        The loaded network, ready for inference.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    weights_file = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")

    model = load_checkpoint_mgpu(U2NET(in_ch=3, out_ch=4), weights_file)
    return model.to(target_device).eval()
|
|
|
def change_bg_color(rgba_image, color):
    """Composite an RGBA image onto a solid background and return it as RGB.

    Args:
        rgba_image: PIL image with an alpha channel; it is also used as the
            paste mask, so its transparent areas show the background.
        color: Background colour, in any form accepted by ``Image.new``.

    Returns:
        A new RGB PIL image of the same size as ``rgba_image``.
    """
    canvas = Image.new("RGBA", rgba_image.size, color)
    canvas.paste(rgba_image, box=(0, 0), mask=rgba_image)
    flattened = canvas.convert("RGB")
    return flattened
|
|
|
|
|
def load_inpainting_pipeline():
    """Load the Stable Diffusion inpainting pipeline onto the best device.

    The ``fp16`` revision of ``runwayml/stable-diffusion-inpainting`` is
    requested in both cases; the weights are loaded as float16 on CUDA and
    as float32 on CPU.

    Returns:
        The ``StableDiffusionInpaintPipeline`` moved to the chosen device.
    """
    if torch.cuda.is_available():
        target_device, weights_dtype = "cuda", torch.float16
    else:
        target_device, weights_dtype = "cpu", torch.float32

    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=weights_dtype,
    )
    return pipeline.to(target_device)
|
def process_image(args, inpainting_pipeline, net):
    """Repaint one garment region of a photo according to a text prompt.

    Steps: load and resize the image, optionally replace its background
    with green, segment it with ``net``, turn the selected garment class
    into a binary mask, inpaint the masked area with
    ``inpainting_pipeline``, then strip the background of the result and
    place it on white.

    Args:
        args: Namespace providing ``image`` (path), ``resolution``,
            ``rembg``, ``part`` (``'upper'`` or ``'lower'``), ``prompt``,
            ``guidance_scale`` and ``num_steps``.
        inpainting_pipeline: Callable inpainting pipeline whose result
            exposes an ``images`` list.
        net: Segmentation network. The first element of its output is
            treated as per-class scores with the class axis at dim 1
            (assumed layout: batch, class, height, width -- confirm
            against the U2NET implementation).

    Returns:
        The final image as an RGB PIL image.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    to_normalized_tensor = transforms.Compose(
        [transforms.ToTensor(), Normalize_image(0.5, 0.5)]
    )

    # The context manager releases the file handle; convert() yields an
    # independent, fully loaded image that outlives the ``with`` block.
    with Image.open(args.image) as source:
        img = source.convert("RGB")
    img = img.resize((args.resolution, args.resolution))

    if args.rembg:
        # change_bg_color already returns an RGB image.
        img_with_green_bg = change_bg_color(remove(img), color="GREEN")
    else:
        img_with_green_bg = img

    image_tensor = to_normalized_tensor(img_with_green_bg).unsqueeze(0)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        # Index of the highest-scoring class per pixel.
        labels = torch.max(output_tensor, dim=1, keepdim=True)[1]

    # Drop the two leading singleton axes to get a 2-D label map.
    labels = torch.squeeze(labels, dim=0)
    labels = torch.squeeze(labels, dim=0)
    output_arr = labels.cpu().numpy()

    # Plain attribute lookup instead of eval() on a user-supplied string.
    mask_code = getattr(Parts, args.part.upper())

    # 255 where the pixel belongs to the requested part, 0 elsewhere.
    mask_arr = np.where(output_arr == mask_code, 255, 0).astype("uint8")
    mask_PIL = Image.fromarray(mask_arr, mode="L")

    clothed_image_from_pipeline = inpainting_pipeline(
        prompt=args.prompt,
        image=img_with_green_bg,
        mask_image=mask_PIL,
        width=args.resolution,
        height=args.resolution,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_steps,
    ).images[0]

    # The result's background is always removed and replaced with white,
    # independently of the --rembg flag.
    clothed_image_from_pipeline = remove(clothed_image_from_pipeline)
    clothed_image_from_pipeline = change_bg_color(clothed_image_from_pipeline, "WHITE")
    return clothed_image_from_pipeline.convert("RGB")
|
if __name__ == '__main__':
    # Script entry point: parse the CLI, load the segmentation network and
    # then the inpainting pipeline, run one image through, save the result.
    cli_args = parse_arguments()
    segmentation_net = load_u2net()
    diffusion_pipeline = load_inpainting_pipeline()

    process_image(cli_args, diffusion_pipeline, segmentation_net).save(cli_args.output)
|
|