jokerbit commited on
Commit
3606bd6
·
verified ·
1 Parent(s): 95e723e

Upload src/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/pipeline.py +4 -2
src/pipeline.py CHANGED
@@ -1,4 +1,4 @@
1
- import gc
2
  import os
3
  from typing import TypeAlias
4
 
@@ -10,7 +10,9 @@ from pipelines.models import TextToImageRequest
10
  from torch import Generator
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel, logging
 
13
 
 
14
 
15
  Pipeline: TypeAlias = FluxPipeline
16
  torch.backends.cudnn.benchmark = True
@@ -52,7 +54,7 @@ def load_pipeline() -> Pipeline:
52
  pipeline.transformer.to(memory_format=torch.channels_last)
53
  pipeline.vae.to(memory_format=torch.channels_last)
54
  quantize_(pipeline.vae, int8_weight_only())
55
- pipeline.vae = torch.compile(pipeline.vae, mode="reduce-overhead", fullgraph=True)
56
  pipeline.to("cuda")
57
 
58
  for _ in range(2):
 
1
+ # import torch_tensorrt
2
  import os
3
  from typing import TypeAlias
4
 
 
10
  from torch import Generator
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel, logging
13
+ from functools import partial
14
 
15
+ my_overhead_compile = partial(torch.compile, mode="reduce-overhead", fullgraph=True)
16
 
17
  Pipeline: TypeAlias = FluxPipeline
18
  torch.backends.cudnn.benchmark = True
 
54
  pipeline.transformer.to(memory_format=torch.channels_last)
55
  pipeline.vae.to(memory_format=torch.channels_last)
56
  quantize_(pipeline.vae, int8_weight_only())
57
+ pipeline.vae = my_overhead_compile(pipeline.vae)
58
  pipeline.to("cuda")
59
 
60
  for _ in range(2):