jokerbit committed
Commit 8498d2f · verified · 1 Parent(s): 05d51ec

Upload src/pipeline.py with huggingface_hub

Files changed (1)
  1. src/pipeline.py +23 -10
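
The commit message above says the file was pushed with huggingface_hub. Below is a minimal sketch of such an upload via HfApi.upload_file; the repo_id and token handling are illustrative assumptions, not details taken from this commit.

# Hypothetical upload sketch: pushes the local pipeline file to a Hub repository.
# The repo_id below is an assumed placeholder; authentication is expected via
# `huggingface-cli login` or the HF_TOKEN environment variable.
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="src/pipeline.py",                 # local file to upload
    path_in_repo="src/pipeline.py",                    # destination path inside the repo
    repo_id="jokerbit/flux-schnell-submission",        # assumed repo id, replace with the real one
    commit_message="Upload src/pipeline.py with huggingface_hub",
)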
src/pipeline.py CHANGED
@@ -10,15 +10,18 @@ from pipelines.models import TextToImageRequest
 from torch import Generator
 from torchao.quantization import quantize_, int8_weight_only
 from transformers import T5EncoderModel, CLIPTextModel, logging
-
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
 
 Pipeline: TypeAlias = FluxPipeline
-torch.backends.cudnn.benchmark = True
+
 torch.backends.cudnn.benchmark = True
 torch._inductor.config.conv_1x1_as_mm = True
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.epilogue_fusion = False
 torch._inductor.config.coordinate_descent_check_all_directions = True
+
+
 os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
 os.environ["TOKENIZERS_PARALLELISM"] = "True"
 CHECKPOINT = "jokerbit/flux.1-schnell-Robert-int8wo"
@@ -29,29 +32,38 @@ TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
 
 
 def load_pipeline() -> Pipeline:
-    path = os.path.join(HF_HUB_CACHE, "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer")
+    path = os.path.join(HF_HUB_CACHE, "models--jokerbit--flux.1-schnell-Robert-int8wo/snapshots/5ef0012f11a863e5111ec56540302a023bc8587b/transformer")
     transformer = FluxTransformer2DModel.from_pretrained(
         path,
         use_safetensors=False,
         local_files_only=True,
         torch_dtype=torch.bfloat16)
+    vae = AutoencoderTiny.from_pretrained(
+        TinyVAE,
+        TinyVAE_REV,
+        local_files_only=True,
+        torch_dtype=torch.bfloat16)
 
     pipeline = FluxPipeline.from_pretrained(
         CHECKPOINT,
         revision=REVISION,
         transformer=transformer,
+        vae=vae,
         local_files_only=True,
         torch_dtype=torch.bfloat16,
     ).to("cuda")
 
     pipeline.transformer.to(memory_format=torch.channels_last)
-    # pipeline.vae.to(memory_format=torch.channels_last)
-    # quantize_(pipeline.vae, int8_weight_only())
+    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=False)
+    pipeline.vae.to(memory_format=torch.channels_last)
+    quantize_(pipeline.vae, int8_weight_only())
     pipeline.vae = torch.compile(pipeline.vae, fullgraph=True, mode="max-autotune")
-
-    for _ in range(2):
-        pipeline("cat", num_inference_steps=4)
-
+
+    PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
+    with torch.inference_mode():
+        for _ in range(4):
+            pipeline(prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
+            torch.cuda.empty_cache()
     return pipeline
 
 @torch.inference_mode()
@@ -75,13 +87,14 @@ if __name__ == "__main__":
         height=None,
         width=None,
         seed=666)
+    generator = torch.Generator(device="cuda")
     start_time = perf_counter()
     pipe_ = load_pipeline()
     stop_time = perf_counter()
     print(f"Pipeline is loaded in {stop_time - start_time}s")
     for _ in range(4):
         start_time = perf_counter()
-        infer(request, pipe_)
+        infer(request, pipe_, generator=generator.manual_seed(request.seed))
         stop_time = perf_counter()
         print(f"Request in {stop_time - start_time}s")
 
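
The optimization pattern applied in load_pipeline() above (torchao int8 weight-only quantization, channels_last memory format, then torch.compile with max-autotune, followed by warmup passes) can be seen in isolation in the sketch below. The toy module, shapes, and warmup count are assumptions for illustration only, not part of the committed pipeline.

# Minimal standalone sketch of the quantize + compile pattern from the diff.
import torch
from torchao.quantization import quantize_, int8_weight_only

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(16, 16, kernel_size=1)   # 1x1 conv, as tuned by conv_1x1_as_mm
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        x = self.conv(x)
        return self.proj(x.flatten(2).transpose(1, 2))

model = ToyBlock().to("cuda", dtype=torch.bfloat16)
model.to(memory_format=torch.channels_last)        # same memory format as the pipeline modules
quantize_(model, int8_weight_only())               # int8 weight-only quantization of Linear layers
model = torch.compile(model, mode="max-autotune", fullgraph=True)

with torch.inference_mode():
    x = torch.randn(1, 16, 32, 32, device="cuda", dtype=torch.bfloat16)
    for _ in range(2):                             # warmup passes trigger compilation/autotuning
        model(x)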