jokerbit committed on
Commit
0e9bb8a
·
verified ·
1 Parent(s): 842a5e2

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +3 -3
src/pipeline.py CHANGED
@@ -11,7 +11,7 @@ from torch import Generator
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel, logging
13
  import torch._dynamo
14
- import torch_tensorrt
15
  torch._dynamo.config.suppress_errors = True
16
 
17
  Pipeline: TypeAlias = FluxPipeline
@@ -54,8 +54,8 @@ def load_pipeline() -> Pipeline:
54
  torch_dtype=torch.bfloat16,
55
  ).to("cuda")
56
 
57
- pipeline.transformer.to(memory_format=torch.channels_last)
58
- pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
59
  quantize_(pipeline.vae, int8_weight_only())
60
  # pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
61
  # pipeline.set_progress_bar_config(disable=True)
 
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel, logging
13
  import torch._dynamo
14
+
15
  torch._dynamo.config.suppress_errors = True
16
 
17
  Pipeline: TypeAlias = FluxPipeline
 
54
  torch_dtype=torch.bfloat16,
55
  ).to("cuda")
56
 
57
+ pipeline.to(memory_format=torch.channels_last)
58
+ pipeline.transformer = torch.compile(pipeline.transformer)
59
  quantize_(pipeline.vae, int8_weight_only())
60
  # pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
61
  # pipeline.set_progress_bar_config(disable=True)