jokerbit committed
Commit 560e4bf · verified · 1 Parent(s): e6fe518
Files changed (1):
  1. src/pipeline.py +20 -11
src/pipeline.py CHANGED
@@ -9,16 +9,17 @@ from huggingface_hub.constants import HF_HUB_CACHE
 from pipelines.models import TextToImageRequest
 from torch import Generator
 from torchao.quantization import quantize_, int8_weight_only
-from transformers import T5EncoderModel, CLIPTextModel
+from transformers import T5EncoderModel, CLIPTextModel, logging
 
+
+Pipeline: TypeAlias = FluxPipeline
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.benchmark = True
 torch._inductor.config.conv_1x1_as_mm = True
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.epilogue_fusion = False
 torch._inductor.config.coordinate_descent_check_all_directions = True
-
-Pipeline: TypeAlias = FluxPipeline
 os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
-torch.backends.cudnn.benchmark = True
 
 CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
 REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"
@@ -34,22 +35,29 @@ def load_pipeline() -> Pipeline:
         use_safetensors=False,
         local_files_only=True,
         torch_dtype=torch.bfloat16)
-
+    vae = AutoencoderTiny.from_pretrained(
+        TinyVAE,
+        revision=TinyVAE_REV,
+        local_files_only=True,
+        torch_dtype=torch.bfloat16)
     pipeline = FluxPipeline.from_pretrained(
         CHECKPOINT,
         revision=REVISION,
         transformer=transformer,
+        # vae=vae,
         local_files_only=True,
         torch_dtype=torch.bfloat16,
     )
+
+    pipeline.transformer.to(memory_format=torch.channels_last)
+    # quantize_(pipeline.vae, int8_weight_only())
+    pipeline.vae.to(memory_format=torch.channels_last)
+    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)
     pipeline.to("cuda")
-    # pipeline.transformer.to(memory_format=torch.channels_last)
-    pipeline.transformer = torch.compile(pipeline.transformer)
-    # pipeline.text_encoder.fuse_qkv_projections()
-    # pipeline.vae = torch.compile(pipeline.vae)
-    for _ in range(4):
+
+    for _ in range(2):
         pipeline("cat", num_inference_steps=4)
-    torch.cuda.empty_cache()
+
     return pipeline
 
 @torch.inference_mode()
@@ -82,3 +90,4 @@ if __name__ == "__main__":
     infer(request, pipe_)
     stop_time = perf_counter()
     print(f"Request in {stop_time - start_time}s")
+
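
Notes on the techniques in this commit, with illustrative sketches; none of the code below is from the commit itself.

The CHECKPOINT repo ships a transformer already quantized with torchao, matching the int8_weight_only import kept above. A minimal sketch of how such a checkpoint could be produced, assuming the stock Flux.1-schnell transformer as the source and an illustrative save path:

import torch
from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only

# Assumed source repo (not from this commit): the original bf16 transformer.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Swap every linear weight for an int8 weight-only quantized tensor in place;
# activations stay bf16, roughly halving the transformer's memory footprint.
quantize_(transformer, int8_weight_only())

# torchao tensor subclasses serialize via pickle rather than safetensors,
# which is likely why load_pipeline() passes use_safetensors=False.
transformer.save_pretrained("flux.1-schnell-int8wo", safe_serialization=False)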
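
The commit also wires up an AutoencoderTiny but leaves it disabled (# vae=vae); the TinyVAE and TinyVAE_REV constants live outside the shown hunks. A hypothetical standalone equivalent, assuming the Flux-compatible madebyollin/taef1 weights:

import torch
from diffusers import AutoencoderTiny, FluxPipeline

# Assumed repo ids (illustrative): madebyollin/taef1 is the tiny autoencoder
# distilled for Flux; the commit's TinyVAE / TinyVAE_REV constants are
# defined in an unshown part of src/pipeline.py.
vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1",
    torch_dtype=torch.bfloat16,
)

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    vae=vae,  # trade a little decode quality for a much cheaper decoder
    torch_dtype=torch.bfloat16,
).to("cuda")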
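
Finally, the torch.compile target moves from the transformer to the VAE, now in channels_last layout, and the warmup loop drops from four runs to two. One caveat: torch.compile on a whole module only compiles its forward(), while FluxPipeline invokes vae.decode(), so compiling the decoder submodule may be the more reliable pattern. A sketch under that assumption, with illustrative latent shapes:

import torch
from diffusers import AutoencoderTiny

# Assumed repo id (not from this commit): the Flux-flavoured tiny VAE.
vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1", torch_dtype=torch.bfloat16
).to("cuda")

# NHWC layout lets inductor/cuDNN pick faster kernels for the
# convolution-heavy decoder; max-autotune benchmarks kernel candidates at
# compile time, and fullgraph=True errors out instead of silently breaking.
vae.to(memory_format=torch.channels_last)
vae.decoder = torch.compile(vae.decoder, mode="max-autotune", fullgraph=True)

# A couple of warmup decodes pay the compile cost outside the timed path,
# mirroring the reduced warmup loop (range(2)) in load_pipeline().
# 16 latent channels at 1/8 spatial resolution, i.e. a 1024x1024 image.
latents = torch.randn(1, 16, 128, 128, dtype=torch.bfloat16, device="cuda")
with torch.inference_mode():
    for _ in range(2):
        vae.decode(latents)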