manbeast3b commited on
Commit
b08b5a9
·
verified ·
1 Parent(s): ba4e88b

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +2 -2
src/pipeline.py CHANGED
@@ -594,8 +594,8 @@ def load_pipeline() -> Pipeline:
594
  ).to(memory_format=torch.channels_last)
595
 
596
  vae = AutoencoderTiny.from_pretrained("silentdriver/7815792fb4", revision="bdb7d88ebe5a1c6b02a3c0c78651dd57a403fdf5", torch_dtype=dtype)
597
- vae.encoder.load_state_dict(torch.load("encoder.pth"))
598
- vae.decoder.load_state_dict(torch.load("decoder.pth"))
599
 
600
  path = os.path.join(HF_HUB_CACHE, "models--silentdriver--7d92df966a/snapshots/add1b8d9a84c728c1209448c4a695759240bad3c")
601
  generator = torch.Generator(device=device)
 
594
  ).to(memory_format=torch.channels_last)
595
 
596
  vae = AutoencoderTiny.from_pretrained("silentdriver/7815792fb4", revision="bdb7d88ebe5a1c6b02a3c0c78651dd57a403fdf5", torch_dtype=dtype)
597
+ vae.encoder.load_state_dict(torch.load("encoder.pth"), strict=False)
598
+ vae.decoder.load_state_dict(torch.load("decoder.pth"), strict=False)
599
 
600
  path = os.path.join(HF_HUB_CACHE, "models--silentdriver--7d92df966a/snapshots/add1b8d9a84c728c1209448c4a695759240bad3c")
601
  generator = torch.Generator(device=device)