# Stable-Makeup-unofficial / inference_utils.py
import os
import cv2
import torch
import random
import numpy as np
# Fix all random seeds and disable cuDNN autotuning so generation is reproducible.
seed = 1024
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from PIL import Image
from gdown import download_folder
from spiga_draw import spiga_process, spiga_segmentation
from pipeline_sd15 import StableDiffusionControlNetPipeline
from diffusers import DDIMScheduler, ControlNetModel
from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
from detail_encoder.encoder_plus import detail_encoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_draw(pil_img, size):
    # Run SPIGA on the face image and return its landmark/segmentation drawing,
    # which is used as structural conditioning for generation.
    cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    spigas = spiga_process(cv2_img)
    if spigas is False:
        # No face detected: fall back to an all-black conditioning image.
        width, height = pil_img.size
        black_image_pil = Image.new("RGB", (width, height), color=(0, 0, 0))
        return black_image_pil
    else:
        spigas_faces = spiga_segmentation(spigas, size=size)
        return spigas_faces
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"])
def concatenate_images(image_files, output_file):
    # Resize all images to a common height, paste them side by side, and save the strip.
    images = image_files  # list of PIL images
    max_height = max(img.height for img in images)
    images = [img.resize((img.width, max_height)) for img in images]
    total_width = sum(img.width for img in images)
    combined = Image.new("RGB", (total_width, max_height))
    x_offset = 0
    for img in images:
        combined.paste(img, (x_offset, 0))
        x_offset += img.width
    combined.save(output_file)
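
# Usage sketch: stitching a source, reference, and result image into one strip for
# visual comparison. The file names below are placeholders, not files shipped with this repo.
# _demo_images = [Image.open(p).convert("RGB") for p in ("id.jpg", "makeup.jpg", "result.jpg")]
# concatenate_images(_demo_images, "comparison.jpg")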
def init_pipeline():
    # Initialize the model
    model_id = "runwayml/stable-diffusion-v1-5"  # or your local sdv1-5 path
    base_path = "./checkpoints/stablemakeup"
    folder_id = "1397t27GrUyLPnj17qVpKWGwg93EcaFfg"
    # Download the pretrained Stable-Makeup weights from Google Drive on first run.
    if not os.path.exists(base_path):
        download_folder(id=folder_id, output=base_path, quiet=False, use_cookies=False)
    makeup_encoder_path = base_path + "/pytorch_model.bin"
    id_encoder_path = base_path + "/pytorch_model_1.bin"
    pose_encoder_path = base_path + "/pytorch_model_2.bin"
    # Base SD-1.5 UNet in fp16; the two ControlNets below are initialized from its weights.
    Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, device=device, subfolder="unet").half()
    id_encoder = ControlNetModel.from_unet(Unet)
    pose_encoder = ControlNetModel.from_unet(Unet)
    makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", device=device, dtype=torch.float16)
    # Load the fine-tuned Stable-Makeup weights into each encoder.
    id_state_dict = torch.load(id_encoder_path, map_location=torch.device("cpu"))
    pose_state_dict = torch.load(pose_encoder_path, map_location=torch.device("cpu"))
    makeup_state_dict = torch.load(makeup_encoder_path, map_location=torch.device("cpu"))
    id_encoder.load_state_dict(id_state_dict, strict=False)
    pose_encoder.load_state_dict(pose_state_dict, strict=False)
    makeup_encoder.load_state_dict(makeup_state_dict, strict=False)
    id_encoder.to(device=device).half()
    pose_encoder.to(device=device).half()
    makeup_encoder.to(device=device).half()
    # Assemble the ControlNet pipeline: the identity and pose encoders condition the UNet.
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        model_id, safety_checker=None, unet=Unet, controlnet=[id_encoder, pose_encoder], device=device, torch_dtype=torch.float16
    ).to(device=device)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    return pipe, makeup_encoder
# Build the pipeline and makeup encoder once, at module import time.
pipeline, makeup_encoder = init_pipeline()
def inference(id_image_pil, makeup_image_pil, guidance_scale=1.6, size=512):
    # Transfer the makeup style of makeup_image_pil onto the face in id_image_pil.
    id_image = id_image_pil.resize((size, size))
    makeup_image = makeup_image_pil.resize((size, size))
    # SPIGA structure drawing of the identity face, used as the second conditioning input.
    pose_image = get_draw(id_image, size=size)
    result_img = makeup_encoder.generate(id_image=[id_image, pose_image], makeup_image=makeup_image, pipe=pipeline, guidance_scale=guidance_scale)
    return result_img
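
# Minimal usage sketch, assuming "id.jpg" (bare face) and "makeup.jpg" (reference style)
# exist next to this script; these paths are illustrative, not part of the repository,
# and the result is treated as a PIL image as elsewhere in this module.
if __name__ == "__main__":
    id_img = Image.open("id.jpg").convert("RGB")
    makeup_img = Image.open("makeup.jpg").convert("RGB")
    result = inference(id_img, makeup_img, guidance_scale=1.6, size=512)
    concatenate_images([id_img.resize((512, 512)), makeup_img.resize((512, 512)), result], "output.jpg")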