sky24h committed
Commit 1a1a5fe · 1 Parent(s): 1d24639
Files changed (2)
  1. inference_utils.py +27 -14
  2. spiga_draw.py +0 -11
inference_utils.py CHANGED
@@ -9,6 +9,17 @@ torch.cuda.manual_seed_all(seed)
 torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
 
+# SPIGA ckpt downloading always fails, so we download it manually and put it in the right place.
+import site
+from gdown import download
+
+user_site_packages_path = site.getusersitepackages()
+spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
+ckpt_path = os.path.join(user_site_packages_path, "spiga/models/weights/spiga_300wpublic.pt")
+if not os.path.exists(ckpt_path):
+    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
+    download(id=spiga_file_id, output=ckpt_path)
+
 from PIL import Image
 from gdown import download_folder
 from facelib import FaceDetector
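Note: `site.getusersitepackages()` only points somewhere useful for `pip install --user` setups; inside a virtualenv or a system-wide install, the user site directory may not be on `sys.path` at all. A minimal sketch of a more portable variant, assuming the `spiga` package itself is importable (the `importlib`-based resolution is an illustration, not part of this commit):

```python
import os
import importlib.util

# Resolve the installed spiga package directory instead of assuming a
# user-site install, then build the path where this commit places the ckpt.
spec = importlib.util.find_spec("spiga")
if spec is not None and spec.origin is not None:
    spiga_dir = os.path.dirname(spec.origin)  # .../site-packages/spiga
    ckpt_path = os.path.join(spiga_dir, "models", "weights", "spiga_300wpublic.pt")
```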
@@ -53,21 +64,21 @@ def concatenate_images(image_files, output_file):
 
 def init_pipeline():
     # Initialize the model
-    model_id = "runwayml/stable-diffusion-v1-5" # or your local sdv1-5 path
+    model_id = "runwayml/stable-diffusion-v1-5" # or your local sdv1-5 path
     base_path = "./checkpoints/stablemakeup"
     folder_id = "1397t27GrUyLPnj17qVpKWGwg93EcaFfg"
     if not os.path.exists(base_path):
         download_folder(id=folder_id, output=base_path, quiet=False, use_cookies=False)
     makeup_encoder_path = base_path + "/pytorch_model.bin"
-    id_encoder_path = base_path + "/pytorch_model_1.bin"
-    pose_encoder_path = base_path + "/pytorch_model_2.bin"
-
-    Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
-    id_encoder = ControlNetModel.from_unet(Unet)
-    pose_encoder = ControlNetModel.from_unet(Unet)
-    makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", "cuda", dtype=torch.float32)
-    id_state_dict = torch.load(id_encoder_path)
-    pose_state_dict = torch.load(pose_encoder_path)
+    id_encoder_path = base_path + "/pytorch_model_1.bin"
+    pose_encoder_path = base_path + "/pytorch_model_2.bin"
+
+    Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to("cuda")
+    id_encoder = ControlNetModel.from_unet(Unet)
+    pose_encoder = ControlNetModel.from_unet(Unet)
+    makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", "cuda", dtype=torch.float32)
+    id_state_dict = torch.load(id_encoder_path)
+    pose_state_dict = torch.load(pose_encoder_path)
     makeup_state_dict = torch.load(makeup_encoder_path)
     id_encoder.load_state_dict(id_state_dict, strict=False)
     pose_encoder.load_state_dict(pose_state_dict, strict=False)
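The state dicts in this hunk are loaded with `torch.load(path)` and no `map_location`, so tensors are restored to whatever device they were saved from. A hedged sketch of a more defensive loader (the `load_encoder` helper is hypothetical, not code from this repo; `strict=False` mirrors the tolerant loading used above):

```python
import torch
from torch import nn

def load_encoder(encoder: nn.Module, ckpt_path: str, device: str = "cuda") -> nn.Module:
    # map_location="cpu" decouples deserialization from the GPU layout the
    # checkpoint was saved with; the module is moved to the target device after.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    encoder.load_state_dict(state_dict, strict=False)
    return encoder.to(device)
```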
@@ -82,14 +93,16 @@ def init_pipeline():
     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
     return pipe, makeup_encoder
 
+
 # Initialize the model
 pipeline, makeup_encoder = init_pipeline()
 
 
 def inference(id_image_pil, makeup_image_pil, guidance_scale=1.6, size=512):
-    id_image = id_image_pil.resize((size, size))
+    id_image = id_image_pil.resize((size, size))
     makeup_image = makeup_image_pil.resize((size, size))
-    pose_image = get_draw(id_image, size=size)
-    result_img = makeup_encoder.generate(id_image=[id_image, pose_image], makeup_image=makeup_image, pipe=pipeline, guidance_scale=guidance_scale)
+    pose_image = get_draw(id_image, size=size)
+    result_img = makeup_encoder.generate(
+        id_image=[id_image, pose_image], makeup_image=makeup_image, pipe=pipeline, guidance_scale=guidance_scale
+    )
     return result_img
-
 
spiga_draw.py CHANGED
@@ -7,17 +7,6 @@ from facelib import FaceDetector
 from spiga.inference.config import ModelConfig
 from spiga.inference.framework import SPIGAFramework
 
-
-
-# SPIGA ckpt downloading always fails, so we download it manually and put it in the right place.
-import site
-from gdown import download
-user_site_packages_path = site.getusersitepackages()
-spiga_file_id = "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"
-ckpt_path = os.path.join(user_site_packages_path, "spiga/models/weights/spiga_300wpublic.pt")
-if not os.path.exists(ckpt_path):
-    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
-    download(id=spiga_file_id, output=ckpt_path)
 processor = SPIGAFramework(ModelConfig("300wpublic"))
 
 def center_crop(image, size):
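With the checkpoint bootstrap moved into inference_utils.py, spiga_draw.py now only builds the SPIGA processor at import time. A minimal usage sketch of the refactored `inference` entry point; the file paths are hypothetical, and it assumes `makeup_encoder.generate` returns a PIL image:

```python
from PIL import Image
from inference_utils import inference

# Hypothetical inputs: a bare-face identity photo and a makeup reference.
id_img = Image.open("examples/id.jpg").convert("RGB")
makeup_img = Image.open("examples/makeup.jpg").convert("RGB")

result = inference(id_img, makeup_img, guidance_scale=1.6, size=512)
result.save("result.png")  # assumes a PIL image is returned
```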