jokerbit commited on
Commit
802934b
·
verified ·
1 Parent(s): 50477bc

1 less line

Browse files
Files changed (1) hide show
  1. src/pipeline.py +2 -4
src/pipeline.py CHANGED
@@ -47,13 +47,11 @@ def load_pipeline() -> Pipeline:
47
  vae=vae,
48
  local_files_only=True,
49
  torch_dtype=torch.bfloat16,
50
- )
51
 
52
- pipeline.transformer.to(memory_format=torch.channels_last)
53
- pipeline.vae.to(memory_format=torch.channels_last)
54
  quantize_(pipeline.vae, int8_weight_only())
55
  pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)
56
- pipeline.to("cuda")
57
 
58
  for _ in range(2):
59
  pipeline("cat", num_inference_steps=4)
 
47
  vae=vae,
48
  local_files_only=True,
49
  torch_dtype=torch.bfloat16,
50
+ ).to("cuda")
51
 
52
+ pipeline.to(memory_format=torch.channels_last)
 
53
  quantize_(pipeline.vae, int8_weight_only())
54
  pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)
 
55
 
56
  for _ in range(2):
57
  pipeline("cat", num_inference_steps=4)