jokerbit commited on
Commit
4d1a25f
·
verified ·
1 Parent(s): 704b49a

Channels last

Browse files
Files changed (1) hide show
  1. src/pipeline.py +2 -1
src/pipeline.py CHANGED
@@ -32,7 +32,7 @@ def load_pipeline() -> Pipeline:
32
  path,
33
  use_safetensors=False,
34
  local_files_only=True,
35
- torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
36
 
37
  pipeline = FluxPipeline.from_pretrained(
38
  CHECKPOINT,
@@ -41,6 +41,7 @@ def load_pipeline() -> Pipeline:
41
  local_files_only=True,
42
  torch_dtype=torch.bfloat16,
43
  )
 
44
  pipeline.to("cuda")
45
  for _ in range(4):
46
  pipeline("cat", num_inference_steps=4)
 
32
  path,
33
  use_safetensors=False,
34
  local_files_only=True,
35
+ torch_dtype=torch.bfloat16)
36
 
37
  pipeline = FluxPipeline.from_pretrained(
38
  CHECKPOINT,
 
41
  local_files_only=True,
42
  torch_dtype=torch.bfloat16,
43
  )
44
+ pipeline.transformer.to(memory_format=torch.channels_last)
45
  pipeline.to("cuda")
46
  for _ in range(4):
47
  pipeline("cat", num_inference_steps=4)