jokerbit committed
Commit 7b00150 (verified) · Parent(s): f6baeb3

Upload src/pipeline.py with huggingface_hub

Files changed (1): src/pipeline.py (+29 / -13)
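For context on the commit message above: a single-file upload like this is typically done with huggingface_hub's HfApi.upload_file. The sketch below is illustrative only; the repo_id is a placeholder, since the target repository is not shown on this page.

from huggingface_hub import HfApi

api = HfApi()  # authenticates via a saved token or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="src/pipeline.py",   # local file to push
    path_in_repo="src/pipeline.py",      # destination path inside the repo
    repo_id="<namespace>/<repo>",        # placeholder, not necessarily this repo
    commit_message="Upload src/pipeline.py with huggingface_hub",
)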
src/pipeline.py CHANGED
@@ -4,18 +4,26 @@ from typing import TypeAlias
 
 import torch
 from PIL.Image import Image
-from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, AutoencoderTiny
+from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, AutoencoderTiny, DiffusionPipeline
 from huggingface_hub.constants import HF_HUB_CACHE
 from pipelines.models import TextToImageRequest
 from torch import Generator
 from torchao.quantization import quantize_, int8_weight_only
-from transformers import T5EncoderModel, CLIPTextModel
-
+from transformers import T5EncoderModel, CLIPTextModel, logging
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
 
 Pipeline: TypeAlias = FluxPipeline
-os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
+
 torch.backends.cudnn.benchmark = True
+torch._inductor.config.conv_1x1_as_mm = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_check_all_directions = True
 
+
+os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
+os.environ["TOKENIZERS_PARALLELISM"] = "True"
 CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
 REVISION = "5ef0012f11a863e5111ec56540302a023bc8587b"
 
@@ -24,12 +32,12 @@ TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
 
 
 def load_pipeline() -> Pipeline:
-    path = os.path.join(HF_HUB_CACHE, "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer")
+    path = os.path.join(HF_HUB_CACHE, "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer")
     transformer = FluxTransformer2DModel.from_pretrained(
         path,
         use_safetensors=False,
         local_files_only=True,
-        torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
+        torch_dtype=torch.bfloat16)
 
     pipeline = FluxPipeline.from_pretrained(
         CHECKPOINT,
@@ -37,13 +45,19 @@ def load_pipeline() -> Pipeline:
         transformer=transformer,
         local_files_only=True,
         torch_dtype=torch.bfloat16,
-    )
-    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
-    pipeline.to("cuda")
-
-    for _ in range(4):
-        pipeline("cat", num_inference_steps=4)
+    ).to("cuda")
 
+    pipeline.transformer.to(memory_format=torch.channels_last)
+    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=False)
+    pipeline.vae.to(memory_format=torch.channels_last)
+    quantize_(pipeline.vae, int8_weight_only())
+    pipeline.vae = torch.compile(pipeline.vae, fullgraph=True, mode="max-autotune")
+
+    PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
+    with torch.inference_mode():
+        for _ in range(4):
+            pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
+    torch.cuda.empty_cache()
     return pipeline
 
 @torch.inference_mode()
@@ -67,12 +81,14 @@ if __name__ == "__main__":
         height=None,
         width=None,
         seed=666)
+    generator = torch.Generator(device="cuda")
    start_time = perf_counter()
    pipe_ = load_pipeline()
    stop_time = perf_counter()
    print(f"Pipeline is loaded in {stop_time - start_time}s")
    for _ in range(4):
        start_time = perf_counter()
-        infer(request, pipe_)
+        infer(request, pipe_, generator=generator.manual_seed(request.seed))
        stop_time = perf_counter()
        print(f"Request in {stop_time - start_time}s")
+
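The new load_pipeline body follows a quantize, compile, then warm-up pattern: torchao's int8_weight_only() stores Linear weights as int8, torch.compile(mode="max-autotune") builds tuned kernels lazily, and the four throwaway pipeline calls pay that compilation cost before the first timed request. Below is a minimal sketch of the same pattern on a toy module, assuming a CUDA device with torch and torchao installed; the module and shapes are made up for illustration and are not part of the commit.

import torch
from torchao.quantization import quantize_, int8_weight_only

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
).to("cuda", dtype=torch.bfloat16)

quantize_(model, int8_weight_only())               # weights stored as int8, dequantized on the fly
model = torch.compile(model, mode="max-autotune")  # tuned kernels are generated lazily, on first call

with torch.inference_mode():
    for _ in range(3):                             # warm-up calls absorb the compile/autotune cost
        model(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))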