# Copyright (c) 2023 Wenet Community. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
Requirements: | |
```bash | |
pip install -U openai-whisper | |
``` | |
Example: | |
```bash | |
# Converts the model from OpenAI to WeNet format: | |
python convert_whisper_to_wenet_config_and_ckpt.py \ | |
--whisper_ckpt large-v3.pt \ | |
--output_dir exp/whisper/large-v3 | |
``` | |
""" | |
import argparse
import copy
import os
import sys
import torch
import yaml
# Temporarily remove the script's own directory from sys.path so that
# ``import whisper`` resolves to the installed openai-whisper package rather
# than anything named ``whisper`` next to this script; restore it afterwards.
_cpath_ = sys.path[0]
sys.path.remove(_cpath_)
from whisper.tokenizer import get_tokenizer
sys.path.insert(0, _cpath_)
def convert_to_wenet_yaml(tokenizer, dims, wenet_yaml_path: str):
    """Generate a WeNet ``train.yaml`` for a converted whisper model.

    Args:
        tokenizer: openai-whisper tokenizer; supplies the special-token ids
            and the reference vocabulary size.
        dims: the ``dims`` dict stored inside an OpenAI whisper checkpoint.
        wenet_yaml_path: where to write the resulting YAML config.
    """
    assert dims['n_vocab'] == tokenizer.encoding.n_vocab, "{} v.s. {}".format(
        dims['n_vocab'], tokenizer.encoding.n_vocab)
    # Vocab-size heuristics mirroring the openai-whisper convention
    # (multilingual models have >= 51865 tokens).
    multilingual = dims['n_vocab'] >= 51865
    num_languages = dims['n_vocab'] - 51765 - int(multilingual)
    special_tokens = {
        'sot': tokenizer.sot,
        'eot': tokenizer.eot,
        'sot_prev': tokenizer.sot_prev,
        'transcribe': tokenizer.transcribe,
        'translate': tokenizer.translate,
        'no_timestamps': tokenizer.no_timestamps,
        'no_speech': tokenizer.no_speech,
        'timestamp_begin': tokenizer.timestamp_begin,
    }
    configs = {
        'input_dim': dims['n_mels'],
        'output_dim': dims['n_vocab'],
        'encoder': 'transformer',
        'encoder_conf': {
            'gradient_checkpointing': True,
            'input_layer': 'conv1d2',
            'output_size': dims['n_audio_state'],
            'attention_heads': dims['n_audio_head'],
            'linear_units': dims['n_audio_state'] * 4,
            'num_blocks': dims['n_audio_layer'],
            'dropout_rate': 0.1,
            'positional_dropout_rate': 0.1,
            'attention_dropout_rate': 0.0,
            'normalize_before': True,
            'use_dynamic_chunk': False,
            'use_dynamic_left_chunk': False,
            'pos_enc_layer_type': "abs_pos_whisper",
            'static_chunk_size': -1,
            'key_bias': False,
            'activation_type': "gelu",
        },
        'decoder': 'transformer',
        'decoder_conf': {
            'tie_word_embedding': True,
            'gradient_checkpointing': True,
            'attention_heads': dims['n_text_head'],
            'linear_units': dims['n_text_state'] * 4,
            'num_blocks': dims['n_text_layer'],
            'dropout_rate': 0.1,
            'positional_dropout_rate': 0.1,
            'self_attention_dropout_rate': 0.0,
            'src_attention_dropout_rate': 0.0,
            'input_layer': "embed_learnable_pe",
            'use_output_layer': True,
            'normalize_before': True,
            'src_attention': True,
            'key_bias': False,
            'activation_type': "gelu",
        },
        'tokenizer': 'whisper',
        'tokenizer_conf': {
            'is_multilingual': multilingual,
            'num_languages': num_languages,
            'split_with_space': False,
            'bpe_path': None,
            'symbol_table_path': None,
            'non_lang_syms_path': None,
            'special_tokens': special_tokens,
        },
        # CTC blank reuses the <|nospeech|> token id.
        'ctc_conf': {
            'ctc_blank_id': tokenizer.no_speech,
        },
        'cmvn': None,
        'cmvn_conf': {
            'cmvn_file': None,
            'is_json_cmvn': None,
        },
        'model': "whisper",
        'model_conf': {
            'ctc_weight': 0.3,
            'lsm_weight': 0.1,
            'length_normalized_loss': False,
        },
        'dataset': "asr",
        'dataset_conf': {
            'filter_conf': {
                'max_length': dims['n_audio_ctx'] * 2,  # 1/2 subsample # noqa
                'min_length': 0,
                'token_max_length': dims['n_text_ctx'],
                'token_min_length': 1,
            },
            'resample_conf': {
                'resample_rate': 16000,
            },
            # NOTE: Disable speed_perturb,
            #   https://github.com/wenet-e2e/wenet/issues/2171
            'speed_perturb': False,
            'spec_aug': True,
            'spec_aug_conf': {
                'num_t_mask': 2,
                'num_f_mask': 2,
                'max_t': 50,
                'max_f': 10,
            },
            'spec_sub': True,
            'spec_sub_conf': {
                'num_t_sub': 3,
                'max_t': 30,
            },
            'spec_trim': False,
            'shuffle': True,
            'shuffle_conf': {
                'shuffle_size': 1500,
            },
            'sort': True,
            'sort_conf': {
                'sort_size': 500,
            },
            'feats_type': "log_mel_spectrogram",
            'log_mel_spectrogram_conf': {
                'n_fft': 400,
                'hop_length': 160,
                'num_mel_bins': dims['n_mels'],
                'padding': 0,
            },
            'batch_conf': {
                'batch_type': 'dynamic',
                'batch_size': 26,
                'max_frames_in_batch': 12000,
            },
            'language_conf': {
                'limited_langs': ['zh'],
            },
        },
        'grad_clip': 5,
        'accum_grad': 4,
        'max_epoch': 100,
        'log_interval': 100,
        'optim': "adam",
        'optim_conf': {
            'lr': 0.0005,
        },
        'scheduler': "warmuplr",
        'scheduler_conf': {
            'warmup_steps': 12000,
        },
    }
    # yaml.dump sorts keys by default, so the serialized file is identical
    # to the one produced by the original key-by-key construction.
    with open(wenet_yaml_path, 'w') as f:
        f.write(yaml.dump(configs))
        f.flush()
    print(configs)
def convert_to_wenet_state_dict(whisper_state_dict, wenet_state_dict_path):
    """Rename whisper parameters to WeNet names and save an fp32 checkpoint.

    Keys for which no rename rule applies are dropped (and reported); the
    encoder/decoder positional embeddings get a leading batch dim so they
    match WeNet's positional-encoding buffer shape.

    Args:
        whisper_state_dict: state dict from an OpenAI whisper checkpoint.
            Unlike the original implementation, it is not modified in place.
        wenet_state_dict_path: output path for the converted checkpoint.
    """
    wenet_state_dict = {}
    unused = []
    print(
        "===================== start CKPT Conversion ========================="
    )
    for original_name, tensor in whisper_state_dict.items():
        # Strings are immutable, so a plain rebind replaces the pointless
        # copy.deepcopy(name) of the original code.
        name = original_name
        name = name.replace("encoder.conv1", "encoder.embed.conv.0")
        name = name.replace("encoder.conv2", "encoder.embed.conv.2")
        name = name.replace("decoder.token_embedding", "decoder.embed.0")
        name = name.replace("encoder.blocks", "encoder.encoders")
        name = name.replace("decoder.blocks", "decoder.decoders")
        # cross_attn must be handled before the generic .attn rules below,
        # otherwise ".attn." would also match inside "cross_attn".
        name = name.replace(".cross_attn.query", ".src_attn.linear_q")
        name = name.replace(".cross_attn.key", ".src_attn.linear_k")
        name = name.replace(".cross_attn.value", ".src_attn.linear_v")
        name = name.replace(".cross_attn.out", ".src_attn.linear_out")
        name = name.replace(".attn.query", ".self_attn.linear_q")
        name = name.replace(".attn.key", ".self_attn.linear_k")
        name = name.replace(".attn.value", ".self_attn.linear_v")
        name = name.replace(".attn.out", ".self_attn.linear_out")
        name = name.replace("mlp.0", "feed_forward.w_1")
        name = name.replace("mlp.2", "feed_forward.w_2")
        if "decoder" in name:
            # Decoder layers have three norms (self-attn, src-attn, ffn).
            name = name.replace("cross_attn_ln", "norm2")
            name = name.replace("mlp_ln", "norm3")
        else:
            name = name.replace("mlp_ln", "norm2")
        name = name.replace("attn_ln", "norm1")
        name = name.replace("encoder.ln_post", "encoder.after_norm")
        name = name.replace("decoder.ln", "decoder.after_norm")
        if original_name == "decoder.positional_embedding":
            # Add batch dim; work on a view instead of mutating the input.
            tensor = tensor.unsqueeze(0)
            name = "decoder.embed.1.pe"
        elif original_name == "encoder.positional_embedding":
            tensor = tensor.unsqueeze(0)
            name = "encoder.embed.pos_enc.pe"
        print("name {} ==> {}".format(original_name, name))
        print("type {} ==> torch.float32".format(tensor.dtype))
        print("shape {}\n".format(tensor.shape))
        if (original_name == name):
            # No rule matched: this parameter has no WeNet counterpart.
            unused.append(name)
        else:
            wenet_state_dict[name] = tensor.float()
    for name in unused:
        print("NOTE!!! drop {}".format(name))
    print("Saving fp32 ckpt to {}...".format(wenet_state_dict_path))
    torch.save(wenet_state_dict, wenet_state_dict_path)
    print(
        "DONE\n===================== End CKPT Conversion =========================\n"
    )
def convert_to_wenet_units(tokenizer, units_txt_path):
    """ NOTE(xcsong):
    The "units.txt" file is solely for adapting to the training API of Wenet
    and for quickly checking the corresponding text of an ID when necessary.
    It does not play any role in the tokenization process,
    which is carried out by the tokenizer of openai-whisper.
    """
    n_vocab = tokenizer.encoding.n_vocab
    with open(units_txt_path, "w") as f:
        for i in range(n_vocab):
            token_bytes = tokenizer.encoding.decode_single_token_bytes(i)
            # BUG FIX: the original tested len(str(bytes)), which is never 0
            # because str(b'') == "b''" (length 4) — the fallback branch was
            # dead code.  Test the raw bytes instead.
            if len(token_bytes) == 0:
                unit = str(i)
                print("can not decode id {}, convert to str({})".format(i, i))
            else:
                unit = str(token_bytes)
            # Spaces would break the whitespace-separated units.txt format.
            unit = unit.replace(" ", "<space>")
            f.write("{} {}\n".format(unit, i))
        f.flush()
def get_args():
    """Build the CLI parser and return the parsed arguments.

    Returns:
        argparse.Namespace with ``whisper_ckpt`` (required) and
        ``output_dir`` (defaults to the current directory).
    """
    parser = argparse.ArgumentParser(description='load and parse whisper')
    # yapf: disable
    parser.add_argument(
        '--whisper_ckpt',
        required=True,
        help='https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt'  # noqa
    )
    # yapf: enable
    parser.add_argument(
        '--output_dir',
        default='.',
        help="output file in wenet's style: units.txt, train.yaml, model.pt")
    return parser.parse_args()
def main():
    """Convert an OpenAI whisper checkpoint to WeNet format.

    Writes three files under ``--output_dir``: ``wenet_whisper.pt``
    (renamed fp32 state dict), ``units.txt`` (vocab listing) and
    ``train.yaml`` (WeNet training config).
    """
    args = get_args()
    checkpoint = torch.load(args.whisper_ckpt, map_location="cpu")
    # Vocab-size heuristics: models with >= 51865 tokens are multilingual;
    # the remainder above 51765 (minus the multilingual extra token)
    # presumably counts the language tokens — mirrors openai-whisper's
    # own convention for get_tokenizer.
    multilingual = checkpoint["dims"]['n_vocab'] >= 51865
    num_languages = checkpoint["dims"]['n_vocab'] - 51765 - int(multilingual)
    tokenizer = get_tokenizer(multilingual=multilingual,
                              num_languages=num_languages)
    convert_to_wenet_state_dict(
        checkpoint["model_state_dict"],
        os.path.join(args.output_dir, 'wenet_whisper.pt'))
    convert_to_wenet_units(tokenizer, os.path.join(args.output_dir,
                                                   'units.txt'))
    convert_to_wenet_yaml(tokenizer, checkpoint["dims"],
                          os.path.join(args.output_dir, 'train.yaml'))
if __name__ == "__main__":
    main()