# uetasr / model.py (author: thanhtvt; commit 4ac7ffc "remove alsd"; 4.11 kB)
import os
import tensorflow as tf
from functools import lru_cache
from huggingface_hub import hf_hub_download
from hyperpyyaml import load_hyperpyyaml
from typing import Union
from decode import get_searcher
# Hide every CUDA device so TensorFlow runs this model on the CPU only.
# NOTE(review): `tensorflow` is imported above this line; the setting only
# takes effect if no GPU context has been initialised yet. Consider moving
# it above `import tensorflow` to be safe.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
def _get_checkpoint_filename(
    repo_id: str,
    filename: str,
    local_dir: Union[str, None] = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "checkpoints",
) -> str:
    """Download one checkpoint file from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository id to download from.
        filename: Name of the file inside ``subfolder``.
        local_dir: Optional target directory; ``None`` defers to
            ``hf_hub_download``'s default location.
        local_dir_use_symlinks: Forwarded unchanged to ``hf_hub_download``.
            NOTE(review): newer huggingface_hub releases deprecate this
            argument; confirm against the pinned version.
        subfolder: Repository subfolder that holds the checkpoint files.

    Returns:
        The local path returned by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
def _get_bpe_model_filename(
    repo_id: str,
    filename: str,
    local_dir: Union[str, None] = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "vocabs",
) -> str:
    """Download one subword (BPE) vocabulary file from a Hugging Face Hub repo.

    Args:
        repo_id: Hub repository id to download from.
        filename: Name of the file inside ``subfolder``.
        local_dir: Optional target directory; ``None`` defers to
            ``hf_hub_download``'s default location.
        local_dir_use_symlinks: Forwarded unchanged to ``hf_hub_download``.
            NOTE(review): newer huggingface_hub releases deprecate this
            argument; confirm against the pinned version.
        subfolder: Repository subfolder that holds the vocabulary files.

    Returns:
        The local path returned by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=local_dir,
        local_dir_use_symlinks=local_dir_use_symlinks,
    )
@lru_cache(maxsize=1)
def _get_conformer_pre_trained_model(
    repo_id: str, checkpoint_dir: str = "checkpoints"
):
    """Download the pre-trained Conformer assets and build the model.

    Fetches the averaged checkpoint (``.index`` and ``.data`` shard), the
    ``subword_vietnamese_500`` BPE files and ``config.yaml`` into this
    module's directory. It then builds every component declared in the
    config and restores the checkpoint weights into ``config["model"]``.

    The result is memoised with ``lru_cache(maxsize=1)``, so repeated calls
    with the same arguments reuse one set of objects.

    Args:
        repo_id: Hugging Face Hub repository id.
        checkpoint_dir: Repository subfolder holding the checkpoint files.

    Returns:
        Tuple ``(audio_encoder, encoder_model, jointer, decoder,
        text_encoder, model)`` looked up from the loaded config.
    """
    # Download next to this module. An absolute path keeps the location
    # independent of the process's current working directory and never
    # yields an empty dirname.
    local_dir = os.path.dirname(os.path.abspath(__file__))
    ckpt_name = "avg_top5_27-32.ckpt"

    ckpt_paths = {}
    for postfix in ["index", "data-00000-of-00001"]:
        ckpt_paths[postfix] = _get_checkpoint_filename(
            repo_id=repo_id,
            filename="{}.{}".format(ckpt_name, postfix),
            subfolder=checkpoint_dir,
            local_dir=local_dir,
            local_dir_use_symlinks=True,
        )
        print(ckpt_paths[postfix])

    for postfix in ["model", "vocab"]:
        bpe_path = _get_bpe_model_filename(
            repo_id=repo_id,
            filename="subword_vietnamese_500.{}".format(postfix),
            local_dir=local_dir,
            local_dir_use_symlinks=True,
        )
        print(bpe_path)

    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.yaml",
        local_dir=local_dir,
        local_dir_use_symlinks=True,
    )
    # SECURITY: hyperpyyaml builds live Python objects from the YAML (here:
    # whole models). Loading a config fetched from the Hub therefore runs
    # code chosen by that repository, so only pass trusted repo_ids.
    # NOTE(review): any relative paths inside config.yaml (e.g. to the BPE
    # files) are presumably still resolved against the working directory.
    # Confirm this.
    with open(config_path, "r", encoding="utf-8") as f:
        config = load_hyperpyyaml(f)

    encoder_model = config["encoder_model"]
    text_encoder = config["text_encoder"]
    jointer = config["jointer_model"]
    decoder = config["decoder_model"]
    # The decoding searcher is not taken from the config; UETASRModel builds
    # it with get_searcher().
    model = config["model"]
    audio_encoder = config["audio_encoder"]

    # A TF checkpoint is addressed by its prefix: the shared stem of the
    # .index and .data-* files. Derive the prefix from the path actually
    # downloaded. The old CWD-relative `checkpoint_dir` path failed whenever
    # the process was started from a different directory.
    ckpt_prefix = ckpt_paths["index"][: -len(".index")]
    model.load_weights(ckpt_prefix).expect_partial()
    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model
def read_audio(in_filename: str):
    """Load a WAV file as a single-utterance batch of mono samples.

    Args:
        in_filename: Path to a WAV file readable by ``tf.audio.decode_wav``.

    Returns:
        A tensor of shape ``(1, num_samples)``. Multi-channel files are
        downmixed to mono by averaging their channels.
    """
    audio = tf.io.read_file(in_filename)
    # decode_wav yields (samples, sample_rate), where samples has shape
    # (num_samples, channels).
    # NOTE(review): the sample rate is discarded. The featurizer presumably
    # expects one fixed rate; confirm that inputs are resampled upstream.
    waveform, _ = tf.audio.decode_wav(audio)
    # Averaging over the channel axis is exact for mono input (the same as
    # the previous squeeze) and downmixes stereo instead of raising.
    waveform = tf.reduce_mean(waveform, axis=-1)
    return tf.expand_dims(waveform, axis=0)
class UETASRModel:
    """Transcribe WAV files with the pre-trained UETASR Conformer model.

    The model components for ``repo_id`` are loaded once (cached by
    ``_get_conformer_pre_trained_model``) and paired with a decoding
    searcher chosen by ``decoding_method``.
    """

    def __init__(
        self,
        repo_id: str,
        decoding_method: str,
        beam_size: int,
        max_symbols_per_step: int,
    ):
        (
            featurizer,
            encoder_model,
            jointer,
            decoder,
            text_encoder,
            model,
        ) = _get_conformer_pre_trained_model(repo_id)
        self.featurizer = featurizer
        self.encoder_model = encoder_model
        self.model = model
        # The jointer, decoder and text encoder are only needed by the
        # searcher, so they are not kept as attributes.
        self.searcher = get_searcher(
            decoding_method,
            decoder,
            jointer,
            text_encoder,
            beam_size,
            max_symbols_per_step,
        )

    def predict(self, in_filename: str):
        """Return the transcription of the WAV file at ``in_filename``."""
        waveform = read_audio(in_filename)
        feats = self.featurizer(waveform)
        if self.model.use_cmvn:
            feats = self.model.cmvn(feats)

        # All-True mask covering every frame of the single utterance,
        # shaped (1, 1, num_frames) for the encoder.
        num_frames = tf.shape(feats)[1]
        frame_mask = tf.expand_dims(
            tf.sequence_mask([num_frames], maxlen=num_frames), axis=1
        )

        enc_out, enc_masks = self.encoder_model(
            feats, frame_mask, training=False)
        # Valid encoder frames per utterance = number of True mask entries.
        enc_lengths = tf.math.reduce_sum(
            tf.cast(tf.squeeze(enc_masks, axis=1), tf.int32),
            axis=1,
        )

        hyps = self.searcher.infer(enc_out, enc_lengths)
        return tf.compat.as_str_any(tf.squeeze(hyps).numpy())