radames committed · commit e5edfc8 · 1 parent: 9a8789a

use pruna for quantization

server/config.py CHANGED
@@ -20,6 +20,7 @@ class Args(BaseModel):
     onediff: bool = False
     compel: bool = False
     debug: bool = False
+    pruna: bool = False

     def pretty_print(self) -> None:
         print("\n")
@@ -123,6 +124,12 @@ parser.add_argument(
     default=False,
     help="Enable OneDiff",
 )
+parser.add_argument(
+    "--pruna",
+    action="store_true",
+    default=False,
+    help="Enable Pruna",
+)
 parser.set_defaults(taesd=USE_TAESD)

 config = Args.model_validate(vars(parser.parse_args()))
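Downstream code reads the flag from the module-level config object built on the last line above. A minimal sketch of how a caller would branch on it (only the pruna field itself is from this commit):

    from config import config

    if config.pruna:
        # opt into pruna smashing, as the pipeline diffs below do via args.pruna
        ...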
server/pipelines/controlnet.py CHANGED
@@ -17,6 +17,8 @@ from config import Args
 from pydantic import BaseModel, Field
 from PIL import Image
 import math
+from pruna import SmashConfig, smash
+from util import ParamsModel

 base_model = "SimianLuo/LCM_Dreamshaper_v7"
 taesd_model = "madebyollin/taesd"
@@ -58,7 +60,7 @@ class Pipeline:
     input_mode: str = "image"
     page_content: str = page_content

-    class InputParams(BaseModel):
+    class InputParams(ParamsModel):
         prompt: str = Field(
             default_prompt,
             title="Prompt",
@@ -170,6 +172,13 @@ class Pipeline:
             taesd_model, torch_dtype=torch_dtype, use_safetensors=True
         ).to(device)

+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=self.pipe, smash_config=smash_config)
+
         if args.sfast:
             print("\nRunning sfast compile\n")
             from sfast.compilers.stable_diffusion_pipeline_compiler import (
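The same deepcache-plus-stable_fast smash recipe recurs in the SD-Turbo, SDXL-Lightning, and SDXL-Turbo pipelines below. As a self-contained sketch of the pattern (the model id is illustrative; the SmashConfig keys and values are the ones this commit uses):

    import torch
    from diffusers import DiffusionPipeline
    from pruna import SmashConfig, smash

    pipe = DiffusionPipeline.from_pretrained(
        "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
    ).to("cuda")

    smash_config = SmashConfig()
    smash_config["cacher"] = "deepcache"      # reuse UNet features across denoising steps
    smash_config["compiler"] = "stable_fast"  # compile the pipeline with stable-fast
    pipe = smash(model=pipe, smash_config=smash_config)  # returns the optimized pipeline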
server/pipelines/img2imgFlux.py CHANGED
@@ -2,21 +2,19 @@ import torch

 from optimum.quanto import freeze, qfloat8, quantize
 from transformers.modeling_utils import PreTrainedModel
-
-from diffusers import (
-    FlowMatchEulerDiscreteScheduler,
-    AutoencoderKL,
-    AutoencoderTiny,
-    FluxImg2ImgPipeline,
-    FluxPipeline,
-)
-
-from diffusers import (
-    FluxImg2ImgPipeline,
-    FluxPipeline,
-    FluxTransformer2DModel,
-    GGUFQuantizationConfig,
-)
+from diffusers import AutoencoderTiny
+from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_flux_img2img import FluxImg2ImgPipeline
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
+
+
+from pruna import smash, SmashConfig
+from pruna.telemetry import set_telemetry_metrics
+
+set_telemetry_metrics(False)  # disable telemetry for current session
+set_telemetry_metrics(False, set_as_default=True)  # disable telemetry globally
+

 try:
     import intel_extension_for_pytorch as ipex  # type: ignore
@@ -76,10 +74,10 @@ class Pipeline:
         1, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
     )
     width: int = Field(
-        256, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
+        1024, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
     )
     height: int = Field(
-        256, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
+        1024, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
     )
     strength: float = Field(
         0.5,
@@ -107,33 +105,101 @@ class Pipeline:
         # "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
         # )
         print("Loading model")
-        # ckpt_path: str = "https://huggingface.co/city96/FLUX.1-schnell-gguf/blob/main/flux1-schnell-Q6_K.gguf"
-        ckpt_path: str = "https://huggingface.co/city96/FLUX.1-schnell-gguf/blob/main/flux1-schnell-Q4_K_S.gguf"
-        transformer = FluxTransformer2DModel.from_single_file(
-            ckpt_path,
-            quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
-            torch_dtype=torch.bfloat16,
-        )
-
-        # else:
-        pipe = FluxImg2ImgPipeline.from_pretrained(
-            # "black-forest-labs/FLUX.1-dev",
-            "black-forest-labs/FLUX.1-Schnell",
+
+        model_id = "black-forest-labs/FLUX.1-schnell"
+        model_revision = "refs/pr/1"
+        text_model_id = "openai/clip-vit-large-patch14"
+        model_data_type = torch.bfloat16
+
+        # CLIP tokenizer and text encoder
+        tokenizer = CLIPTokenizer.from_pretrained(
+            text_model_id, torch_dtype=model_data_type
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            text_model_id, torch_dtype=model_data_type
+        )
+
+        # T5 tokenizer and text encoder
+        tokenizer_2 = T5TokenizerFast.from_pretrained(
+            model_id,
+            subfolder="tokenizer_2",
+            torch_dtype=model_data_type,
+            revision=model_revision,
+        )
+        text_encoder_2 = T5EncoderModel.from_pretrained(
+            model_id,
+            subfolder="text_encoder_2",
+            torch_dtype=model_data_type,
+            revision=model_revision,
+        )
+
+        # Scheduler and transformer
+        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+            model_id, subfolder="scheduler", revision=model_revision
+        )
+        transformer = FluxTransformer2DModel.from_pretrained(
+            model_id,
+            subfolder="transformer",
+            torch_dtype=model_data_type,
+            revision=model_revision,
+        )
+
+        # VAE
+        # vae = AutoencoderKL.from_pretrained(
+        #     model_id,
+        #     subfolder="vae",
+        #     torch_dtype=model_data_type,
+        #     revision=model_revision,
+        # )
+        vae = AutoencoderTiny.from_pretrained(
+            "madebyollin/taef1", torch_dtype=torch.bfloat16
+        )
+
+        # Initialize the SmashConfig
+        smash_config = SmashConfig()
+        smash_config["quantizer"] = "quanto"
+        smash_config["quanto_calibrate"] = False
+        smash_config["quanto_weight_bits"] = "qint4"  # or "qint2", "qint8", "qfloat8"
+
+        transformer = smash(
+            model=transformer,
+            smash_config=smash_config,
+        )
+        text_encoder_2 = smash(
+            model=text_encoder_2,
+            smash_config=smash_config,
+        )
+
+        pipe = FluxImg2ImgPipeline(
+            scheduler=scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
+            vae=vae,
             transformer=transformer,
-            torch_dtype=torch.bfloat16,
         )
-        if args.taesd:
-            pipe.vae = AutoencoderTiny.from_pretrained(
-                taesd_path, torch_dtype=torch.bfloat16, use_safetensors=True
-            )
+
+        # if args.taesd:
+        #     pipe.vae = AutoencoderTiny.from_pretrained(
+        #         taesd_path, torch_dtype=torch.bfloat16, use_safetensors=True
+        #     )
         # pipe.enable_model_cpu_offload()
-        pipe = pipe.to(device)
+        pipe.text_encoder.to(device)
+        pipe.vae.to(device)
+        pipe.transformer.to(device)
+        pipe.text_encoder_2.to(device)

         # pipe.enable_model_cpu_offload()
+        # For extra memory savings, at a cost in speed:
+        # vae.enable_tiling()
+        # vae.enable_slicing()
+        # pipe.enable_sequential_cpu_offload()

         self.pipe = pipe
         self.pipe.set_progress_bar_config(disable=True)
-
         # vae = AutoencoderKL.from_pretrained(
         #     base_model_path, subfolder="vae", torch_dtype=torch_dtype
         # )
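Smashing the two largest components with 4-bit quanto weights is where the memory savings come from. A rough back-of-envelope, assuming approximate parameter counts of 12B for the FLUX.1-schnell transformer and 4.7B for the T5-XXL text_encoder_2:

    # Approximate weight memory before/after qint4 (ignores scales and activations).
    for name, n_params in [("transformer", 12e9), ("text_encoder_2", 4.7e9)]:
        bf16_gb = n_params * 2 / 1e9     # 2 bytes per weight in bfloat16
        qint4_gb = n_params * 0.5 / 1e9  # 4 bits per weight
        print(f"{name}: {bf16_gb:.1f} GB bf16 -> ~{qint4_gb:.1f} GB qint4")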
server/pipelines/img2imgSDTurbo.py CHANGED
@@ -15,6 +15,7 @@ from PIL import Image
 from util import ParamsModel
 import math

+from pruna import smash, SmashConfig

 base_model = "stabilityai/sd-turbo"
 taesd_model = "madebyollin/taesd"
@@ -102,6 +103,13 @@ class Pipeline:
             taesd_model, torch_dtype=torch_dtype, use_safetensors=True
         ).to(device)

+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=self.pipe, smash_config=smash_config)
+
         if args.sfast:
             from sfast.compilers.stable_diffusion_pipeline_compiler import (
                 compile,
@@ -130,8 +138,8 @@ class Pipeline:

         self.pipe.set_progress_bar_config(disable=True)
         self.pipe.to(device=device, dtype=torch_dtype)
-        if device.type != "mps":
-            self.pipe.unet.to(memory_format=torch.channels_last)
+        # if device.type != "mps":
+        #     self.pipe.unet.to(memory_format=torch.channels_last)

         if args.torch_compile:
             print("Running torch compile")
server/pipelines/img2imgSDXL-Lightning.py CHANGED
@@ -20,6 +20,7 @@ from pydantic import BaseModel, Field
 from PIL import Image
 from util import ParamsModel
 import math
+from pruna import SmashConfig, smash

 base = "stabilityai/stable-diffusion-xl-base-1.0"
 repo = "ByteDance/SDXL-Lightning"
@@ -135,6 +136,13 @@ class Pipeline:
             self.pipe.scheduler.config, timestep_spacing="trailing"
         )

+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=self.pipe, smash_config=smash_config)
+
         if args.sfast:
             from sfast.compilers.stable_diffusion_pipeline_compiler import (
                 compile,
server/pipelines/img2imgSDXLTurbo.py CHANGED
@@ -17,6 +17,13 @@ from PIL import Image
 from util import ParamsModel
 import math

+from pruna import smash, SmashConfig
+from pruna.telemetry import set_telemetry_metrics
+
+set_telemetry_metrics(False)  # disable telemetry for current session
+set_telemetry_metrics(False, set_as_default=True)  # disable telemetry globally
+
+
 base_model = "stabilityai/sdxl-turbo"
 taesd_model = "madebyollin/taesdxl"

@@ -104,10 +111,11 @@ class Pipeline:
     )

     def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
-        self.pipe = AutoPipelineForImage2Image.from_pretrained(
+        base_pipe = AutoPipelineForImage2Image.from_pretrained(
             base_model,
             safety_checker=None,
         )
+        self.pipe = base_pipe
         if args.taesd:
             self.pipe.vae = AutoencoderTiny.from_pretrained(
                 taesd_model, torch_dtype=torch_dtype, use_safetensors=True
@@ -125,11 +133,16 @@ class Pipeline:
             config.enable_cuda_graph = True
             self.pipe = compile(self.pipe, config=config)

-        self.pipe.set_progress_bar_config(disable=True)
-        self.pipe.to(device=device, dtype=torch_dtype)
         if device.type != "mps":
             self.pipe.unet.to(memory_format=torch.channels_last)

+        if args.pruna:
+            # Create and smash your model
+            smash_config = SmashConfig()
+            smash_config["cacher"] = "deepcache"
+            smash_config["compiler"] = "stable_fast"
+            self.pipe = smash(model=base_pipe, smash_config=smash_config)
+
         if args.torch_compile:
             print("Running torch compile")
             self.pipe.unet = torch.compile(
@@ -151,6 +164,9 @@ class Pipeline:
             requires_pooled=[False, True],
         )

+        self.pipe.set_progress_bar_config(disable=True)
+        self.pipe.to(device=device, dtype=torch_dtype)
+
     def predict(self, params: "Pipeline.InputParams") -> Image.Image:
         generator = torch.manual_seed(params.seed)
         prompt = params.prompt
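Note the reordering in this constructor: the pipeline is smashed first, and only afterwards is the result moved to the target device and dtype. In outline, with names as in the file above:

    pipe = AutoPipelineForImage2Image.from_pretrained(base_model, safety_checker=None)
    if args.pruna:
        smash_config = SmashConfig()
        smash_config["cacher"] = "deepcache"
        smash_config["compiler"] = "stable_fast"
        pipe = smash(model=pipe, smash_config=smash_config)
    # device placement is deferred until after optional smashing/compiling
    pipe.set_progress_bar_config(disable=True)
    pipe.to(device=device, dtype=torch_dtype)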
server/requirements.txt CHANGED
@@ -1,30 +1,35 @@
-diffusers==0.32.0
-transformers==4.47.1
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.5.1
+torchvision
+torchaudio
+xformers; sys_platform != 'darwin' or platform_machine != 'arm64'
+numpy
+diffusers
+llvmlite>=0.39.0
+numba>=0.56.0
+pruna[stable-fast] ; sys_platform != 'darwin' or platform_machine != 'arm64'
+transformers
+pydantic
 huggingface-hub
 hf_transfer
---extra-index-url https://download.pytorch.org/whl/cu121;
-torch==2.3.0
-fastapi==0.115.6
-uvicorn[standard]==0.34.0
+fastapi
+uvicorn[standard]
 Pillow==11.0.0
-accelerate==1.2.1
+accelerate
 compel==2.0.2
 controlnet-aux==0.0.9
 peft==0.14.0
-xformers; sys_platform != 'darwin' or platform_machine != 'arm64'
 markdown2
 safetensors
-stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/nightly/stable_fast-1.0.5.dev20241127+torch230cu121-cp310-cp310-manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
+# stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/nightly/stable_fast-1.0.5.dev20241127+torch230cu121-cp310-cp310-manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
 #oneflow @ https://github.com/siliconflow/oneflow_releases/releases/download/community_cu121/oneflow-0.9.1.dev20241114%2Bcu121-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
 #onediff @ git+https://github.com/siliconflow/onediff.git@main#egg=onediff ; sys_platform != 'darwin' or platform_machine != 'arm64'
 setuptools
 mpmath==1.3.0
-numpy==1.*
 controlnet-aux
 sentencepiece==0.2.0
-optimum-quanto
+optimum-quanto  # has to be optimum-quanto==0.2.5 for pruna int4
 gguf==0.13.0
-pydantic>=2.7.0
 types-Pillow
 mypy
 python-dotenv