"""Finetuning methods."""
import logging
import os
import torch
from collections import OrderedDict
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import torch_load
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.mt_interface import MTInterface
from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.dynamic_import import dynamic_import
def freeze_modules(model, modules):
"""Freeze model parameters according to modules list.
Args:
model (torch.nn.Module): main model to update
modules (list): specified module list for freezing
Return:
model (torch.nn.Module): updated model
model_params (filter): filtered model parameters
"""
for mod, param in model.named_parameters():
if any(mod.startswith(m) for m in modules):
logging.info(f"freezing {mod}, it will not be updated.")
param.requires_grad = False
model_params = filter(lambda x: x.requires_grad, model.parameters())
return model, model_params
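

# A minimal usage sketch, not part of the original API: freeze_modules is
# typically called before building the optimizer so that the frozen parameters
# are excluded from the update. The "enc." prefix below is illustrative.
def _example_freeze_encoder(model):
    """Freeze all parameters under a hypothetical "enc." prefix."""
    model, model_params = freeze_modules(model, ["enc."])
    optimizer = torch.optim.Adam(model_params, lr=1e-3)
    return model, optimizer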
def transfer_verification(model_state_dict, partial_state_dict, modules):
"""Verify tuples (key, shape) for input model modules match specified modules.
Args:
model_state_dict (OrderedDict): the initial model state_dict
partial_state_dict (OrderedDict): the trained model state_dict
modules (list): specified module list for transfer
Return:
        (bool): whether the transfer is allowed
"""
modules_model = []
partial_modules = []
for key_p, value_p in partial_state_dict.items():
if any(key_p.startswith(m) for m in modules):
partial_modules += [(key_p, value_p.shape)]
for key_m, value_m in model_state_dict.items():
if any(key_m.startswith(m) for m in modules):
modules_model += [(key_m, value_m.shape)]
len_match = len(modules_model) == len(partial_modules)
module_match = sorted(modules_model, key=lambda x: (x[0], x[1])) == sorted(
partial_modules, key=lambda x: (x[0], x[1])
)
return len_match and module_match
def get_partial_state_dict(model_state_dict, modules):
"""Create state_dict with specified modules matching input model modules.
    Note that get_lm_state_dict is used instead if an LM is specified.
Args:
model_state_dict (OrderedDict): trained model state_dict
modules (list): specified module list for transfer
Return:
new_state_dict (OrderedDict): the updated state_dict
"""
new_state_dict = OrderedDict()
for key, value in model_state_dict.items():
if any(key.startswith(m) for m in modules):
new_state_dict[key] = value
return new_state_dict
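

# A minimal sketch, assuming "enc." is a valid module prefix in both models:
# extract the matching weights from a pre-trained state_dict, verify they are
# compatible with the current model, then copy them over.
def _example_transfer_encoder(main_model, pretrained_state_dict):
    modules = ["enc."]
    partial_state_dict = get_partial_state_dict(pretrained_state_dict, modules)
    main_state_dict = main_model.state_dict()
    if transfer_verification(main_state_dict, partial_state_dict, modules):
        main_state_dict.update(partial_state_dict)
        main_model.load_state_dict(main_state_dict)
    return main_model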
def get_lm_state_dict(lm_state_dict):
"""Create compatible ASR decoder state dict from LM state dict.
Args:
lm_state_dict (OrderedDict): pre-trained LM state_dict
Return:
new_state_dict (OrderedDict): LM state_dict with updated keys
"""
new_state_dict = OrderedDict()
for key, value in list(lm_state_dict.items()):
if key == "predictor.embed.weight":
new_state_dict["dec.embed.weight"] = value
elif key.startswith("predictor.rnn."):
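            # e.g. a key such as "predictor.rnn.0.weight_ih" is split into
            # ["predictor", "rnn", "0", "weight_ih"] and renamed to
            # "dec.decoder.0.weight_ih_l0" to match the ASR decoder naming.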
_split = key.split(".")
new_key = "dec.decoder." + _split[2] + "." + _split[3] + "_l0"
new_state_dict[new_key] = value
return new_state_dict
def filter_modules(model_state_dict, modules):
"""Filter non-matched modules in module_state_dict.
Args:
model_state_dict (OrderedDict): trained model state_dict
modules (list): specified module list for transfer
Return:
        new_mods (list): the updated module list
"""
new_mods = []
incorrect_mods = []
mods_model = list(model_state_dict.keys())
for mod in modules:
if any(key.startswith(mod) for key in mods_model):
new_mods += [mod]
else:
incorrect_mods += [mod]
if incorrect_mods:
logging.warning(
"module(s) %s don't match or (partially match) "
"available modules in model.",
incorrect_mods,
)
logging.warning("for information, the existing modules in model are:")
logging.warning("%s", mods_model)
return new_mods
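

# Illustrative example (the key names are assumptions): given a state_dict
# with keys such as "enc.enc.0.weight" and "dec.embed.weight",
# filter_modules(state_dict, ["enc.", "att."]) returns ["enc."] and logs a
# warning for the unmatched "att." prefix.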
def load_trained_model(model_path, training=True):
"""Load the trained model for recognition.
    Args:
        model_path (str): Path to model.***.best
        training (bool): whether the loaded model will be trained further
            (only used by transducer models).
    Return:
        model (torch.nn.Module): the loaded model
        train_args (Namespace): the trained model arguments
    """
idim, odim, train_args = get_model_conf(
model_path, os.path.join(os.path.dirname(model_path), "model.json")
)
logging.warning("reading model parameters from " + model_path)
if hasattr(train_args, "model_module"):
model_module = train_args.model_module
else:
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    # CTC loss is not needed here; default to the builtin implementation to avoid import errors
if hasattr(train_args, "ctc_type"):
train_args.ctc_type = "builtin"
model_class = dynamic_import(model_module)
if "transducer" in model_module:
model = model_class(idim, odim, train_args, training=training)
custom_torch_load(model_path, model, training=training)
else:
model = model_class(idim, odim, train_args)
torch_load(model_path, model)
return model, train_args
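

# A minimal usage sketch; the checkpoint path is hypothetical.
def _example_load_for_recognition():
    model, train_args = load_trained_model(
        "exp/train/results/model.acc.best", training=False
    )
    model.eval()
    return model, train_args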
def get_trained_model_state_dict(model_path):
"""Extract the trained model state dict for pre-initialization.
Args:
model_path (str): Path to model.***.best
    Return:
        model.state_dict() (OrderedDict): the loaded model state_dict
            (an LM state_dict is first converted with get_lm_state_dict)
"""
conf_path = os.path.join(os.path.dirname(model_path), "model.json")
if "rnnlm" in model_path:
logging.warning("reading model parameters from %s", model_path)
return get_lm_state_dict(torch.load(model_path))
idim, odim, args = get_model_conf(model_path, conf_path)
logging.warning("reading model parameters from " + model_path)
if hasattr(args, "model_module"):
model_module = args.model_module
else:
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
model_class = dynamic_import(model_module)
model = model_class(idim, odim, args)
torch_load(model_path, model)
    assert isinstance(model, (MTInterface, ASRInterface, TTSInterface))
return model.state_dict()
def load_trained_modules(idim, odim, args, interface=ASRInterface):
"""Load model encoder or/and decoder modules with ESPNET pre-trained model(s).
Args:
idim (int): initial input dimension.
odim (int): initial output dimension.
args (Namespace): The initial model arguments.
        interface (Interface): ASRInterface, STInterface, or TTSInterface.
Return:
model (torch.nn.Module): The model with pretrained modules.
"""
def print_new_keys(state_dict, modules, model_path):
logging.warning("loading %s from model: %s", modules, model_path)
for k in state_dict.keys():
logging.warning("override %s" % k)
enc_model_path = args.enc_init
dec_model_path = args.dec_init
enc_modules = args.enc_init_mods
dec_modules = args.dec_init_mods
model_class = dynamic_import(args.model_module)
main_model = model_class(idim, odim, args)
assert isinstance(main_model, interface)
main_state_dict = main_model.state_dict()
logging.warning("model(s) found for pre-initialization")
for model_path, modules in [
(enc_model_path, enc_modules),
(dec_model_path, dec_modules),
]:
if model_path is not None:
if os.path.isfile(model_path):
model_state_dict = get_trained_model_state_dict(model_path)
modules = filter_modules(model_state_dict, modules)
partial_state_dict = get_partial_state_dict(model_state_dict, modules)
if partial_state_dict:
if transfer_verification(
main_state_dict, partial_state_dict, modules
):
print_new_keys(partial_state_dict, modules, model_path)
main_state_dict.update(partial_state_dict)
else:
logging.warning(
f"modules {modules} in model {model_path} "
f"don't match your training config",
)
else:
logging.warning("model was not found : %s", model_path)
main_model.load_state_dict(main_state_dict)
return main_model
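

# A minimal sketch of how load_trained_modules is driven by the args Namespace.
# The checkpoint paths and module prefixes below are hypothetical; args is also
# expected to carry model_module and the other fields needed by model_class.
def _example_load_trained_modules(idim, odim, args):
    args.enc_init = "exp/pretrained_asr/results/model.acc.best"
    args.enc_init_mods = ["enc."]
    args.dec_init = "exp/pretrained_lm/rnnlm.model.best"
    args.dec_init_mods = ["dec.", "att."]
    return load_trained_modules(idim, odim, args, interface=ASRInterface)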