File size: 1,680 Bytes
0209786 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import torch
def convert(
    checkpoint: str,
    outdir: str,
    suffix: str = "base",
    map_location: str = "cpu",
):
    """Split a training checkpoint into a generator checkpoint and a detector checkpoint.

    The input checkpoint must be a dict with an ``"xp.cfg"`` mapping and a
    ``"model"`` state dict. Every model key must be one of:

    * ``"detector.<name>"`` -> stored as ``<name>`` in the detector checkpoint;
    * ``"generator.<name>"`` -> stored as ``<name>`` in the generator checkpoint;
    * ``"msg_processor.msg_processor.0.weight"`` -> stored in the generator
      checkpoint as ``"msg_processor.msg_processor.weight"``.

    Both outputs carry a reduced ``"xp.cfg"`` containing only ``seanet``,
    ``channels``, ``dtype`` and ``sample_rate``.

    Args:
        checkpoint: Path of the checkpoint to convert.
        outdir: Output directory; created (with parents) if it does not exist.
        suffix: Written files are ``checkpoint_generator_{suffix}.pth`` and
            ``checkpoint_detector_{suffix}.pth``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from GPU training converts on machines without CUDA.

    Raises:
        KeyError: If ``"xp.cfg"`` lacks one of the inference keys listed above.
        ValueError: If a model key matches none of the expected patterns.
    """
    detector_prefix = "detector."
    generator_prefix = "generator."
    # This weight is stored under a ".0." index in the training checkpoint,
    # but the generator checkpoint expects it without the index.
    msg_key_in = "msg_processor.msg_processor.0.weight"
    msg_key_out = "msg_processor.msg_processor.weight"

    outdir_path = Path(outdir)
    ckpt = torch.load(checkpoint, map_location=map_location)

    # Keep inference-related params only.
    train_cfg = ckpt["xp.cfg"]
    infer_cfg = {
        key: train_cfg[key] for key in ("seanet", "channels", "dtype", "sample_rate")
    }

    generator_state = {}
    detector_state = {}
    for layer, value in ckpt["model"].items():
        if layer.startswith(detector_prefix):
            detector_state[layer[len(detector_prefix):]] = value
        elif layer == msg_key_in:
            generator_state[msg_key_out] = value
        elif layer.startswith(generator_prefix):
            generator_state[layer[len(generator_prefix):]] = value
        else:
            # Raise explicitly instead of asserting so the check survives `python -O`.
            raise ValueError(f"Invalid layer: {layer}")

    outdir_path.mkdir(parents=True, exist_ok=True)
    torch.save(
        {"xp.cfg": infer_cfg, "model": generator_state},
        outdir_path / f"checkpoint_generator_{suffix}.pth",
    )
    torch.save(
        {"xp.cfg": infer_cfg, "model": detector_state},
        outdir_path / f"checkpoint_detector_{suffix}.pth",
    )
if __name__ == "__main__":
    # Expose `convert` as a command-line interface; `fire` is only needed
    # when the file is run as a script, so it is imported lazily here.
    from fire import Fire

    Fire(convert)
|