jokerbit commited on
Commit
a5f3be1
·
verified ·
1 Parent(s): afb0eb8
Files changed (1) hide show
  1. src/pipeline.py +2 -2
src/pipeline.py CHANGED
@@ -40,9 +40,9 @@ def load_pipeline() -> Pipeline:
40
  transformer=transformer,
41
  local_files_only=True,
42
  torch_dtype=torch.bfloat16,
43
- ).to("cuda")
44
  quantize_(pipeline.vae, int8_weight_only())
45
- pipeline.to(memory_format=torch.channels_last)
46
  # pipeline.transformer.to(memory_format=torch.channels_last)
47
 
48
  PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'
 
40
  transformer=transformer,
41
  local_files_only=True,
42
  torch_dtype=torch.bfloat16,
43
+ )
44
  quantize_(pipeline.vae, int8_weight_only())
45
+ pipeline.to(memory_format=torch.channels_last, device="cuda")
46
  # pipeline.transformer.to(memory_format=torch.channels_last)
47
 
48
  PROMPT = 'semiconformity, peregrination, quip, twineless, emotionless, tawa, depickle'