jokerbit committed on
Commit
668f50f
·
verified ·
1 Parent(s): 19023b1

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +2 -3
src/pipeline.py CHANGED
@@ -11,7 +11,6 @@ from torch import Generator
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel, logging
13
  import torch._dynamo
14
- import torch_tensorrt
15
  # torch._dynamo.config.suppress_errors = True
16
 
17
  Pipeline: TypeAlias = FluxPipeline
@@ -47,11 +46,11 @@ def load_pipeline() -> Pipeline:
47
  local_files_only=True,
48
  torch_dtype=torch.bfloat16,
49
  ).to("cuda")
50
- # pipeline.to(memory_format=torch.channels_last)
51
  # pipeline.transformer.to(memory_format=torch.channels_last)
52
  # pipeline.vae.to(memory_format=torch.channels_last)
53
 
54
- quantize_(pipeline.vae, int8_weight_only())
55
 
56
  PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
57
  with torch.inference_mode():
 
11
  from torchao.quantization import quantize_, int8_weight_only
12
  from transformers import T5EncoderModel, CLIPTextModel, logging
13
  import torch._dynamo
 
14
  # torch._dynamo.config.suppress_errors = True
15
 
16
  Pipeline: TypeAlias = FluxPipeline
 
46
  local_files_only=True,
47
  torch_dtype=torch.bfloat16,
48
  ).to("cuda")
49
+ pipeline.to(memory_format=torch.channels_last)
50
  # pipeline.transformer.to(memory_format=torch.channels_last)
51
  # pipeline.vae.to(memory_format=torch.channels_last)
52
 
53
+ # quantize_(pipeline.vae, int8_weight_only())
54
 
55
  PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
56
  with torch.inference_mode():