# uetasr/decode.py
# Author: thanhtvt
# Last commit: "remove alsd" (4ac7ffc)
import logging
import tensorflow as tf
from functools import lru_cache
from uetasr.searchers import GreedyRNNT, BeamRNNT
@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Build the RNN-T searcher selected by ``searcher_type``.

    Results are memoized with ``lru_cache`` (up to 5 distinct argument
    combinations), so repeated calls with the same arguments return the
    same searcher object. Every argument is part of the cache key and
    therefore has to be hashable.

    Args:
        searcher_type: ``"greedy_search"`` to get a ``GreedyRNNT`` or
            ``"beam_search"`` to get a ``BeamRNNT``.
        decoder: Forwarded to the searcher as ``decoder``.
        jointer: Forwarded to the searcher as ``jointer``.
        text_decoder: Forwarded to the searcher as ``text_decoder``.
        beam_size: Forwarded to ``BeamRNNT`` as ``beam``; ignored for
            greedy search.
        max_symbols_per_step: Forwarded to the searcher under the same name.

    Returns:
        The constructed ``GreedyRNNT`` or ``BeamRNNT`` instance, created
        with ``return_scores=False``.

    Raises:
        ValueError: If ``searcher_type`` is not one of the supported names.
    """
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }

    if searcher_type == "greedy_search":
        return GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )

    if searcher_type == "beam_search":
        return BeamRNNT(
            max_symbols_per_step=max_symbols_per_step,
            beam=beam_size,
            # NOTE(review): alpha is pinned to 0.0 here; its meaning is
            # defined by BeamRNNT — confirm before exposing it to callers.
            alpha=0.0,
            **common_kwargs,
        )

    # Previously this case was only logged and the function then crashed on
    # an unbound local; fail explicitly with an actionable message instead.
    # lru_cache does not memoize raised exceptions, so nothing is cached here.
    raise ValueError(
        f"Unknown searcher type: {searcher_type!r} "
        "(expected 'greedy_search' or 'beam_search')"
    )