manbeast3b committed on
Commit
136c023
·
verified ·
1 Parent(s): eddb286

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +7 -7
src/pipeline.py CHANGED
@@ -589,12 +589,13 @@ def load_pipeline() -> Pipeline:
589
  empty_cache()
590
 
591
  dtype, device = torch.bfloat16, "cuda"
592
-
593
  text_encoder_2 = T5EncoderModel.from_pretrained(
594
  "silentdriver/aadb864af9", revision = "060dabc7fa271c26dfa3fd43c16e7c5bf3ac7892", torch_dtype=torch.bfloat16
595
  ).to(memory_format=torch.channels_last)
596
 
597
  vae = AutoencoderTiny.from_pretrained("silentdriver/7815792fb4", revision="bdb7d88ebe5a1c6b02a3c0c78651dd57a403fdf5", torch_dtype=dtype)
 
 
598
 
599
  path = os.path.join(HF_HUB_CACHE, "models--silentdriver--7d92df966a/snapshots/add1b8d9a84c728c1209448c4a695759240bad3c")
600
  generator = torch.Generator(device=device)
@@ -610,12 +611,11 @@ def load_pipeline() -> Pipeline:
610
  torch_dtype=dtype,
611
  ).to(device)
612
 
613
- pipeline.text_encoder.to(memory_format=torch.channels_last)
614
- pipeline.text_encoder_2.to(memory_format=torch.channels_last)
615
- pipeline.transformer.to(memory_format=torch.channels_last)
616
- pipeline.vae.to(memory_format=torch.channels_last)
617
-
618
- for _ in range(3):
619
  pipeline(prompt="blah blah waah waah oneshot oneshot gang gang", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
620
 
621
  empty_cache()
 
589
  empty_cache()
590
 
591
  dtype, device = torch.bfloat16, "cuda"
 
592
  text_encoder_2 = T5EncoderModel.from_pretrained(
593
  "silentdriver/aadb864af9", revision = "060dabc7fa271c26dfa3fd43c16e7c5bf3ac7892", torch_dtype=torch.bfloat16
594
  ).to(memory_format=torch.channels_last)
595
 
596
  vae = AutoencoderTiny.from_pretrained("silentdriver/7815792fb4", revision="bdb7d88ebe5a1c6b02a3c0c78651dd57a403fdf5", torch_dtype=dtype)
597
+ vae.encoder.load_state_dict(torch.load("encoder.pth"))
598
+ vae.decoder.load_state_dict(torch.load("decoder.pth"))
599
 
600
  path = os.path.join(HF_HUB_CACHE, "models--silentdriver--7d92df966a/snapshots/add1b8d9a84c728c1209448c4a695759240bad3c")
601
  generator = torch.Generator(device=device)
 
611
  torch_dtype=dtype,
612
  ).to(device)
613
 
614
+ # Optimize memory format
615
+ for component in [pipeline.text_encoder, pipeline.text_encoder_2, pipeline.transformer, pipeline.vae]:
616
+ component.to(memory_format=torch.channels_last)
617
+
618
+ for _ in range(2):
 
619
  pipeline(prompt="blah blah waah waah oneshot oneshot gang gang", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
620
 
621
  empty_cache()