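"""Convert fairseq data2vec checkpoints into self-contained s3prl checkpoints."""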
from pathlib import Path

import torch

import s3prl
from s3prl.upstream.data2vec.data2vec_model import (
    Data2VecAudioConfig,
    Data2VecAudioModel,
)
from s3prl.upstream.utils import load_fairseq_ckpt, merge_with_parent
from s3prl.upstream.wav2vec2.wav2vec2_model import AudioPretrainingConfig


def load_and_convert_fairseq_ckpt(fairseq_source: str, output_path: str):
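    """Convert a fairseq data2vec checkpoint and save it to output_path."""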
    state, cfg = load_fairseq_ckpt(fairseq_source)

    # Keep only what is needed to rebuild the model: task config, model config, weights
    output_state = {
        "task_cfg": cfg["task"],
        "model_cfg": cfg["model"],
        "model_weight": state["model"],
    }

    Path(output_path).parent.mkdir(exist_ok=True, parents=True)
    torch.save(output_state, output_path)

    # Reload the converted checkpoint right away as a sanity check
    load_converted_model(output_path)


def load_converted_model(ckpt: str):
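    """Rebuild an s3prl Data2VecAudioModel from a converted checkpoint.

    Returns the model together with its parsed task config.
    """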
    ckpt_state = torch.load(ckpt, map_location="cpu")

    for required_key in ["task_cfg", "model_cfg", "model_weight"]:
        if required_key not in ckpt_state:
            raise ValueError(
                f"{ckpt} is not a valid checkpoint since the required key: {required_key} is missing"
            )

    # Parse the stored config dicts back into typed config objects
    task_cfg = merge_with_parent(AudioPretrainingConfig, ckpt_state["task_cfg"])
    model_cfg = merge_with_parent(Data2VecAudioConfig, ckpt_state["model_cfg"])
    model = Data2VecAudioModel(model_cfg)

    # Pretraining-only components (e.g. the EMA teacher stored under "_ema")
    # are not needed for feature extraction, so strip them before loading
    model.remove_pretraining_modules()
    del ckpt_state["model_weight"]["_ema"]

    model.load_state_dict(ckpt_state["model_weight"])
    return model, task_cfg


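# Example invocation (script and checkpoint paths are illustrative):
#   python convert.py /path/to/data2vec_checkpoint.pt --output_dir ./converted_ckpts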
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("fairseq_ckpt")
    parser.add_argument(
        "--output_dir", default=Path(s3prl.__file__).parent.parent / "converted_ckpts"
    )
    args = parser.parse_args()

    Path(args.output_dir).mkdir(exist_ok=True, parents=True)
    load_and_convert_fairseq_ckpt(
        args.fairseq_ckpt, Path(args.output_dir) / f"{Path(args.fairseq_ckpt).stem}.pt"
    )