lmzjms's picture
Upload 1162 files
0b32ad6 verified
from pathlib import Path
import torch
import s3prl
from s3prl.upstream.data2vec.data2vec_model import (
Data2VecAudioConfig,
Data2VecAudioModel,
)
from s3prl.upstream.utils import load_fairseq_ckpt, merge_with_parent
from s3prl.upstream.wav2vec2.wav2vec2_model import AudioPretrainingConfig
def load_and_convert_fairseq_ckpt(fairseq_source: str, output_path: str):
    """Convert a fairseq checkpoint into the three-key layout used by this file.

    The checkpoint is read with ``load_fairseq_ckpt``, repackaged into a dict
    with the keys ``"task_cfg"``, ``"model_cfg"`` and ``"model_weight"``,
    written to ``output_path`` with ``torch.save``, and finally loaded back
    through ``load_converted_model`` as a sanity check.

    Args:
        fairseq_source: Source checkpoint handed to ``load_fairseq_ckpt``.
        output_path: Destination file; missing parent directories are created.
    """
    fairseq_state, fairseq_cfg = load_fairseq_ckpt(fairseq_source)
    converted = dict(
        task_cfg=fairseq_cfg["task"],
        model_cfg=fairseq_cfg["model"],
        model_weight=fairseq_state["model"],
    )

    destination_dir = Path(output_path).parent
    destination_dir.mkdir(parents=True, exist_ok=True)
    torch.save(converted, output_path)

    # Round-trip the file we just wrote so a broken conversion fails here
    # rather than later when someone tries to use the checkpoint.
    load_converted_model(output_path)
def load_converted_model(ckpt: str):
    """Build a ``Data2VecAudioModel`` from a converted checkpoint file.

    Args:
        ckpt: Path to a checkpoint containing the keys ``"task_cfg"``,
            ``"model_cfg"`` and ``"model_weight"``.

    Returns:
        A ``(model, task_cfg)`` tuple. ``model`` is a ``Data2VecAudioModel``
        on which ``remove_pretraining_modules()`` has been called and the
        stored weights loaded; ``task_cfg`` is the stored task config merged
        with ``AudioPretrainingConfig`` via ``merge_with_parent``.

    Raises:
        ValueError: If any of the required keys is missing from the file.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    ckpt_state = torch.load(ckpt, map_location="cpu")

    for required_key in ("task_cfg", "model_cfg", "model_weight"):
        if required_key not in ckpt_state:
            raise ValueError(
                f"{ckpt} is not a valid checkpoint since the required key: {required_key} is missing"
            )

    task_cfg = merge_with_parent(AudioPretrainingConfig, ckpt_state["task_cfg"])
    model_cfg = merge_with_parent(Data2VecAudioConfig, ckpt_state["model_cfg"])

    model = Data2VecAudioModel(model_cfg)
    model.remove_pretraining_modules()

    # The "_ema" entry is discarded before loading (presumably it belongs to
    # the pretraining-only EMA teacher -- TODO confirm). Use pop() with a
    # default so checkpoints that do not carry the entry still load instead
    # of failing with KeyError.
    model_weight = ckpt_state["model_weight"]
    model_weight.pop("_ema", None)
    model.load_state_dict(model_weight)
    return model, task_cfg
if __name__ == "__main__":
    # Command-line entry point: convert one fairseq checkpoint and store the
    # result as "<output_dir>/<checkpoint stem>.pt".
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("fairseq_ckpt", help="fairseq checkpoint to convert")
    parser.add_argument(
        "--output_dir",
        default=Path(s3prl.__file__).parent.parent / "converted_ckpts",
        help="directory that receives the converted checkpoint",
    )
    args = parser.parse_args()

    # Create the output directory itself (not just its parent) so the
    # destination exists before the conversion starts.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # The converter's signature is annotated with str paths, so hand it one.
    output_path = output_dir / f"{Path(args.fairseq_ckpt).stem}.pt"
    load_and_convert_fairseq_ckpt(args.fairseq_ckpt, str(output_path))