Update src/pipeline.py
Browse files- src/pipeline.py +2 -2
src/pipeline.py
CHANGED
@@ -594,8 +594,8 @@ def load_pipeline() -> Pipeline:
|
|
594 |
).to(memory_format=torch.channels_last)
|
595 |
|
596 |
vae = AutoencoderTiny.from_pretrained("silentdriver/7815792fb4", revision="bdb7d88ebe5a1c6b02a3c0c78651dd57a403fdf5", torch_dtype=dtype)
|
597 |
-
vae.encoder.load_state_dict(torch.load("encoder.pth"))
|
598 |
-
vae.decoder.load_state_dict(torch.load("decoder.pth"))
|
599 |
|
600 |
path = os.path.join(HF_HUB_CACHE, "models--silentdriver--7d92df966a/snapshots/add1b8d9a84c728c1209448c4a695759240bad3c")
|
601 |
generator = torch.Generator(device=device)
|
|
|
594 |
).to(memory_format=torch.channels_last)
|
595 |
|
596 |
vae = AutoencoderTiny.from_pretrained("silentdriver/7815792fb4", revision="bdb7d88ebe5a1c6b02a3c0c78651dd57a403fdf5", torch_dtype=dtype)
|
597 |
+
vae.encoder.load_state_dict(torch.load("encoder.pth"), strict=False)
|
598 |
+
vae.decoder.load_state_dict(torch.load("decoder.pth"), strict=False)
|
599 |
|
600 |
path = os.path.join(HF_HUB_CACHE, "models--silentdriver--7d92df966a/snapshots/add1b8d9a84c728c1209448c4a695759240bad3c")
|
601 |
generator = torch.Generator(device=device)
|