use RealVisXL_V5.0_Lightning
Browse files- pulid/pipeline.py +21 -14
pulid/pipeline.py
CHANGED
@@ -8,6 +8,7 @@ from basicsr.utils import img2tensor, tensor2img
|
|
8 |
from diffusers import (
|
9 |
DPMSolverMultistepScheduler,
|
10 |
StableDiffusionXLPipeline,
|
|
|
11 |
UNet2DConditionModel,
|
12 |
)
|
13 |
from facexlib.parsing import init_parsing_model
|
@@ -34,22 +35,28 @@ class PuLIDPipeline:
|
|
34 |
def __init__(self, *args, **kwargs):
|
35 |
super().__init__()
|
36 |
self.device = 'cuda'
|
37 |
-
sdxl_base_repo = 'stabilityai/stable-diffusion-xl-base-1.0'
|
38 |
-
sdxl_lightning_repo = 'ByteDance/SDXL-Lightning'
|
39 |
-
self.sdxl_base_repo = sdxl_base_repo
|
40 |
|
41 |
# load base model
|
42 |
-
unet = UNet2DConditionModel.from_config(sdxl_base_repo, subfolder='unet').to(self.device, torch.float16)
|
43 |
-
unet.load_state_dict(
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
)
|
48 |
-
unet.half()
|
49 |
-
self.hack_unet_attn_layers(unet)
|
50 |
-
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
51 |
-
|
52 |
-
).to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
self.pipe.watermark = None
|
54 |
|
55 |
# scheduler
|
|
|
8 |
from diffusers import (
|
9 |
DPMSolverMultistepScheduler,
|
10 |
StableDiffusionXLPipeline,
|
11 |
+
AutoPipelineForText2Image,
|
12 |
UNet2DConditionModel,
|
13 |
)
|
14 |
from facexlib.parsing import init_parsing_model
|
|
|
35 |
def __init__(self, *args, **kwargs):
|
36 |
super().__init__()
|
37 |
self.device = 'cuda'
|
38 |
+
#sdxl_base_repo = 'stabilityai/stable-diffusion-xl-base-1.0'
|
39 |
+
#sdxl_lightning_repo = 'ByteDance/SDXL-Lightning'
|
40 |
+
#self.sdxl_base_repo = sdxl_base_repo
|
41 |
|
42 |
# load base model
|
43 |
+
#unet = UNet2DConditionModel.from_config(sdxl_base_repo, subfolder='unet').to(self.device, torch.float16)
|
44 |
+
#unet.load_state_dict(
|
45 |
+
# load_file(
|
46 |
+
# hf_hub_download(sdxl_lightning_repo, 'sdxl_lightning_4step_unet.safetensors'), device=self.device
|
47 |
+
# )
|
48 |
+
#)
|
49 |
+
#unet.half()
|
50 |
+
#self.hack_unet_attn_layers(unet)
|
51 |
+
#self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
52 |
+
# sdxl_base_repo, unet=unet, torch_dtype=torch.float16, variant="fp16"
|
53 |
+
#).to(self.device)
|
54 |
+
|
55 |
+
# SG161222/RealVisXL_V5.0_Lightning, lykon/dreamshaper-xl-lightning
|
56 |
+
self.pipe = AutoPipelineForText2Image.from_pretrained('SG161222/RealVisXL_V5.0_Lightning', torch_dtype=torch.float16, variant="fp16")
|
57 |
+
self.hack_unet_attn_layers(self.pipe.unet)
|
58 |
+
pipe = pipe.to(self.device)
|
59 |
+
|
60 |
self.pipe.watermark = None
|
61 |
|
62 |
# scheduler
|