TobDeBer commited on
Commit
d8db4fe
·
verified ·
1 Parent(s): afd93fd

use RealVisXL_V5.0_Lightning

Browse files
Files changed (1) hide show
  1. pulid/pipeline.py +21 -14
pulid/pipeline.py CHANGED
@@ -8,6 +8,7 @@ from basicsr.utils import img2tensor, tensor2img
8
  from diffusers import (
9
  DPMSolverMultistepScheduler,
10
  StableDiffusionXLPipeline,
 
11
  UNet2DConditionModel,
12
  )
13
  from facexlib.parsing import init_parsing_model
@@ -34,22 +35,28 @@ class PuLIDPipeline:
34
  def __init__(self, *args, **kwargs):
35
  super().__init__()
36
  self.device = 'cuda'
37
- sdxl_base_repo = 'stabilityai/stable-diffusion-xl-base-1.0'
38
- sdxl_lightning_repo = 'ByteDance/SDXL-Lightning'
39
- self.sdxl_base_repo = sdxl_base_repo
40
 
41
  # load base model
42
- unet = UNet2DConditionModel.from_config(sdxl_base_repo, subfolder='unet').to(self.device, torch.float16)
43
- unet.load_state_dict(
44
- load_file(
45
- hf_hub_download(sdxl_lightning_repo, 'sdxl_lightning_4step_unet.safetensors'), device=self.device
46
- )
47
- )
48
- unet.half()
49
- self.hack_unet_attn_layers(unet)
50
- self.pipe = StableDiffusionXLPipeline.from_pretrained(
51
- sdxl_base_repo, unet=unet, torch_dtype=torch.float16, variant="fp16"
52
- ).to(self.device)
 
 
 
 
 
 
53
  self.pipe.watermark = None
54
 
55
  # scheduler
 
8
  from diffusers import (
9
  DPMSolverMultistepScheduler,
10
  StableDiffusionXLPipeline,
11
+ AutoPipelineForText2Image,
12
  UNet2DConditionModel,
13
  )
14
  from facexlib.parsing import init_parsing_model
 
35
  def __init__(self, *args, **kwargs):
36
  super().__init__()
37
  self.device = 'cuda'
38
+ #sdxl_base_repo = 'stabilityai/stable-diffusion-xl-base-1.0'
39
+ #sdxl_lightning_repo = 'ByteDance/SDXL-Lightning'
40
+ #self.sdxl_base_repo = sdxl_base_repo
41
 
42
  # load base model
43
+ #unet = UNet2DConditionModel.from_config(sdxl_base_repo, subfolder='unet').to(self.device, torch.float16)
44
+ #unet.load_state_dict(
45
+ # load_file(
46
+ # hf_hub_download(sdxl_lightning_repo, 'sdxl_lightning_4step_unet.safetensors'), device=self.device
47
+ # )
48
+ #)
49
+ #unet.half()
50
+ #self.hack_unet_attn_layers(unet)
51
+ #self.pipe = StableDiffusionXLPipeline.from_pretrained(
52
+ # sdxl_base_repo, unet=unet, torch_dtype=torch.float16, variant="fp16"
53
+ #).to(self.device)
54
+
55
+ # SG161222/RealVisXL_V5.0_Lightning, lykon/dreamshaper-xl-lightning
56
+ self.pipe = AutoPipelineForText2Image.from_pretrained('SG161222/RealVisXL_V5.0_Lightning', torch_dtype=torch.float16, variant="fp16")
57
+ self.hack_unet_attn_layers(self.pipe.unet)
58
+ pipe = pipe.to(self.device)
59
+
60
  self.pipe.watermark = None
61
 
62
  # scheduler