|
import logging |
|
import tensorflow as tf |
|
from functools import lru_cache |
|
from uetasr.searchers import GreedyRNNT, BeamRNNT |
|
|
|
|
|
@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Build (and memoize) an RNN-T searcher of the requested type.

    The result is cached with ``functools.lru_cache(maxsize=5)``, so calling
    this again with the same arguments returns the *same* searcher object
    rather than a new one. Consequently every argument must be hashable, and
    the cache keeps strong references to the models/layers passed in for as
    long as their entry stays cached.

    Args:
        searcher_type: ``"greedy_search"`` to build a ``GreedyRNNT`` or
            ``"beam_search"`` to build a ``BeamRNNT``.
        decoder: Model forwarded to the searcher as ``decoder``.
        jointer: Model forwarded to the searcher as ``jointer``.
        text_decoder: Layer forwarded to the searcher as ``text_decoder``.
        beam_size: Beam width passed as ``beam``; only used when
            ``searcher_type`` is ``"beam_search"``.
        max_symbols_per_step: Forwarded unchanged to either searcher.

    Returns:
        A ``GreedyRNNT`` or ``BeamRNNT`` instance, always constructed with
        ``return_scores=False`` (and ``alpha=0.0`` for beam search).

    Raises:
        ValueError: If ``searcher_type`` is not a supported value.
    """
    # Arguments shared by every searcher implementation.
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }

    if searcher_type == "greedy_search":
        searcher = GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )
    elif searcher_type == "beam_search":
        searcher = BeamRNNT(
            max_symbols_per_step=max_symbols_per_step,
            beam=beam_size,
            alpha=0.0,
            **common_kwargs,
        )
    else:
        # Fail loudly here: falling through would reach `return searcher`
        # with the name never bound and surface as an opaque
        # UnboundLocalError. (lru_cache does not cache raised exceptions.)
        raise ValueError(
            f"Unknown searcher type: {searcher_type!r}; "
            "expected 'greedy_search' or 'beam_search'"
        )

    return searcher
|
|