import logging
from functools import lru_cache

import tensorflow as tf

from uetasr.searchers import GreedyRNNT, BeamRNNT


@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Construct an RNN-T searcher of the requested type.

    The result is memoized with ``functools.lru_cache(maxsize=5)``. Three
    consequences follow:

    * Every argument must be hashable, because the arguments form the cache key.
    * Repeated calls with identical arguments return the *same* searcher object.
    * Up to five recent argument tuples are kept alive by the cache. This
      includes the model objects passed in.

    Args:
        searcher_type: Either ``"greedy_search"`` (builds ``GreedyRNNT``) or
            ``"beam_search"`` (builds ``BeamRNNT``).
        decoder: Model forwarded to the searcher as ``decoder``.
        jointer: Model forwarded to the searcher as ``jointer``.
        text_decoder: Layer forwarded to the searcher as ``text_decoder``.
        beam_size: Passed as ``beam`` to ``BeamRNNT``. It is ignored for
            ``"greedy_search"``.
        max_symbols_per_step: Forwarded unchanged to either searcher.

    Returns:
        The constructed ``GreedyRNNT`` or ``BeamRNNT`` instance, built with
        ``return_scores=False``.

    Raises:
        ValueError: If ``searcher_type`` is not one of the supported values.
    """
    # Arguments shared by both searcher implementations.
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }

    if searcher_type == "greedy_search":
        return GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )

    if searcher_type == "beam_search":
        return BeamRNNT(
            max_symbols_per_step=max_symbols_per_step,
            beam=beam_size,
            # NOTE(review): alpha is fixed at 0.0 here; its meaning is defined
            # by BeamRNNT (not visible in this file) — confirm before exposing.
            alpha=0.0,
            **common_kwargs,
        )

    # Previously this branch only logged and then hit `return searcher` with
    # `searcher` unbound, surfacing as an opaque UnboundLocalError.
    logging.error("Unknown searcher type: %s", searcher_type)
    raise ValueError(f"Unknown searcher type: {searcher_type}")