add pruna
Browse files
server/pipelines/txt2img.py
CHANGED
@@ -12,6 +12,7 @@ from pydantic import BaseModel, Field
|
|
12 |
from util import ParamsModel
|
13 |
from PIL import Image
|
14 |
from typing import List
|
|
|
15 |
|
16 |
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
17 |
taesd_model = "madebyollin/taesd"
|
@@ -86,6 +87,13 @@ class Pipeline:
|
|
86 |
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
87 |
).to(device)
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
if args.sfast:
|
90 |
from sfast.compilers.stable_diffusion_pipeline_compiler import (
|
91 |
compile,
|
|
|
12 |
from util import ParamsModel
|
13 |
from PIL import Image
|
14 |
from typing import List
|
15 |
+
from pruna import SmashConfig, smash
|
16 |
|
17 |
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
18 |
taesd_model = "madebyollin/taesd"
|
|
|
87 |
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
88 |
).to(device)
|
89 |
|
90 |
+
if args.pruna:
|
91 |
+
# Create and smash your model
|
92 |
+
smash_config = SmashConfig()
|
93 |
+
# smash_config["cacher"] = "deepcache"
|
94 |
+
smash_config["compiler"] = "stable_fast"
|
95 |
+
self.pipe = smash(model=self.pipe, smash_config=smash_config)
|
96 |
+
|
97 |
if args.sfast:
|
98 |
from sfast.compilers.stable_diffusion_pipeline_compiler import (
|
99 |
compile,
|
server/pipelines/txt2imgLora.py
CHANGED
@@ -12,6 +12,7 @@ from config import Args
|
|
12 |
from pydantic import BaseModel, Field
|
13 |
from util import ParamsModel
|
14 |
from PIL import Image
|
|
|
15 |
|
16 |
base_model = "wavymulder/Analog-Diffusion"
|
17 |
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
@@ -93,6 +94,13 @@ class Pipeline:
|
|
93 |
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
94 |
).to(device)
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
97 |
self.pipe.set_progress_bar_config(disable=True)
|
98 |
self.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
|
|
|
12 |
from pydantic import BaseModel, Field
|
13 |
from util import ParamsModel
|
14 |
from PIL import Image
|
15 |
+
from pruna import SmashConfig, smash
|
16 |
|
17 |
base_model = "wavymulder/Analog-Diffusion"
|
18 |
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
|
|
94 |
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
95 |
).to(device)
|
96 |
|
97 |
+
if args.pruna:
|
98 |
+
# Create and smash your model
|
99 |
+
smash_config = SmashConfig()
|
100 |
+
# smash_config["cacher"] = "deepcache"
|
101 |
+
smash_config["compiler"] = "stable_fast"
|
102 |
+
self.pipe = smash(model=self.pipe, smash_config=smash_config)
|
103 |
+
|
104 |
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
105 |
self.pipe.set_progress_bar_config(disable=True)
|
106 |
self.pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
|