jokerbit commited on
Commit
626e60b
·
verified ·
1 Parent(s): cbebdce

Upload src/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/pipeline.py +3 -2
src/pipeline.py CHANGED
@@ -11,7 +11,8 @@ from torch import Generator
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel, logging
13
  import torch._dynamo
14
- torch._dynamo.config.suppress_errors = True
 
15
 
16
  Pipeline: TypeAlias = FluxPipeline
17
 
@@ -50,7 +51,7 @@ def load_pipeline() -> Pipeline:
50
  pipeline.transformer.to(memory_format=torch.channels_last)
51
  pipeline.vae.to(memory_format=torch.channels_last)
52
  quantize_(pipeline.vae, int8_weight_only())
53
- pipeline.vae = torch.compile(pipeline.vae, fullgraph=True, mode="max-autotune")
54
 
55
  PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
56
  with torch.inference_mode():
 
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel, logging
13
  import torch._dynamo
14
+ import torch_tensorrt
15
+ # torch._dynamo.config.suppress_errors = True
16
 
17
  Pipeline: TypeAlias = FluxPipeline
18
 
 
51
  pipeline.transformer.to(memory_format=torch.channels_last)
52
  pipeline.vae.to(memory_format=torch.channels_last)
53
  quantize_(pipeline.vae, int8_weight_only())
54
+ pipeline.vae = torch.compile(pipeline.vae, fullgraph=True, backend="tensorrt")
55
 
56
  PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
57
  with torch.inference_mode():