File size: 1,010 Bytes
e9812a3 4ac7ffc e9812a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import logging
import tensorflow as tf
from functools import lru_cache
from uetasr.searchers import GreedyRNNT, BeamRNNT
@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Build (and cache) an RNN-T searcher of the requested type.

    Results are memoized with ``functools.lru_cache`` (at most 5 entries),
    keyed on every argument. Repeated calls with identical arguments
    therefore return the same searcher object, and all arguments must be
    hashable.

    Args:
        searcher_type: ``"greedy_search"`` to build a ``GreedyRNNT`` or
            ``"beam_search"`` to build a ``BeamRNNT``.
        decoder: Keras model passed to the searcher as ``decoder``.
        jointer: Keras model passed to the searcher as ``jointer``.
        text_decoder: Preprocessing layer passed to the searcher as
            ``text_decoder``.
        beam_size: Passed to ``BeamRNNT`` as ``beam``; ignored for
            ``"greedy_search"``.
        max_symbols_per_step: Passed to either searcher as
            ``max_symbols_per_step``.

    Returns:
        A ``GreedyRNNT`` or ``BeamRNNT`` instance constructed with
        ``return_scores=False``.

    Raises:
        ValueError: If ``searcher_type`` is not a supported value.
    """
    # Arguments shared by every searcher implementation.
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }

    if searcher_type == "greedy_search":
        return GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )

    if searcher_type == "beam_search":
        return BeamRNNT(
            max_symbols_per_step=max_symbols_per_step,
            beam=beam_size,
            # NOTE(review): alpha is pinned to 0.0; its meaning is defined
            # by BeamRNNT -- confirm this is the intended setting.
            alpha=0.0,
            **common_kwargs,
        )

    # Fail loudly rather than falling through with no searcher bound.
    # lru_cache does not cache raised exceptions, so a later call with a
    # valid type is unaffected.
    raise ValueError(f"Unknown searcher type: {searcher_type}")
|