sky24h commited on
Commit
d7e604f
·
1 Parent(s): 61f61e0

add ZeroGPU support

Browse files
Files changed (1) hide show
  1. inference_utils.py +4 -4
inference_utils.py CHANGED
@@ -63,7 +63,7 @@ def init_pipeline():
63
  id_encoder_path = base_path + "/pytorch_model_1.bin"
64
  pose_encoder_path = base_path + "/pytorch_model_2.bin"
65
 
66
- Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, device=device, subfolder="unet")
67
  id_encoder = ControlNetModel.from_unet(Unet)
68
  pose_encoder = ControlNetModel.from_unet(Unet)
69
  makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", device=device, dtype=torch.float16)
@@ -73,9 +73,9 @@ def init_pipeline():
73
  id_encoder.load_state_dict(id_state_dict, strict=False)
74
  pose_encoder.load_state_dict(pose_state_dict, strict=False)
75
  makeup_encoder.load_state_dict(makeup_state_dict, strict=False)
76
- id_encoder.to(device=device)
77
- pose_encoder.to(device=device)
78
- makeup_encoder.to(device=device)
79
 
80
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
81
  model_id, safety_checker=None, unet=Unet, controlnet=[id_encoder, pose_encoder], device=device, torch_dtype=torch.float16
 
63
  id_encoder_path = base_path + "/pytorch_model_1.bin"
64
  pose_encoder_path = base_path + "/pytorch_model_2.bin"
65
 
66
+ Unet = OriginalUNet2DConditionModel.from_pretrained(model_id, device=device, subfolder="unet").half()
67
  id_encoder = ControlNetModel.from_unet(Unet)
68
  pose_encoder = ControlNetModel.from_unet(Unet)
69
  makeup_encoder = detail_encoder(Unet, "openai/clip-vit-large-patch14", device=device, dtype=torch.float16)
 
73
  id_encoder.load_state_dict(id_state_dict, strict=False)
74
  pose_encoder.load_state_dict(pose_state_dict, strict=False)
75
  makeup_encoder.load_state_dict(makeup_state_dict, strict=False)
76
+ id_encoder.to(device=device).half()
77
+ pose_encoder.to(device=device).half()
78
+ makeup_encoder.to(device=device).half()
79
 
80
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
81
  model_id, safety_checker=None, unet=Unet, controlnet=[id_encoder, pose_encoder], device=device, torch_dtype=torch.float16