# Spaces: Runtime error
import logging | |
import tensorflow as tf | |
from functools import lru_cache | |
from uetasr.searchers import GreedyRNNT, BeamRNNT | |
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Construct an RNN-T searcher of the requested type.

    Args:
        searcher_type: Either ``"greedy_search"`` (builds ``GreedyRNNT``) or
            ``"beam_search"`` (builds ``BeamRNNT``).
        decoder: Keras model forwarded to the searcher as ``decoder``.
        jointer: Keras model forwarded to the searcher as ``jointer``.
        text_decoder: Preprocessing layer forwarded as ``text_decoder``.
        beam_size: Forwarded to ``BeamRNNT`` as ``beam``; ignored for
            ``"greedy_search"``.
        max_symbols_per_step: Forwarded to either searcher unchanged.

    Returns:
        The constructed searcher, always created with ``return_scores=False``
        (and ``alpha=0.0`` for beam search).

    Raises:
        ValueError: If ``searcher_type`` is not one of the supported values.
            Previously this path only logged and then crashed with an
            ``UnboundLocalError`` on the ``return`` statement.
    """
    # Arguments shared by every searcher implementation.
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }

    if searcher_type == "greedy_search":
        return GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )

    if searcher_type == "beam_search":
        return BeamRNNT(
            max_symbols_per_step=max_symbols_per_step,
            beam=beam_size,
            alpha=0.0,
            **common_kwargs,
        )

    # Fail loudly with a clear message instead of falling through to an
    # unbound local variable.
    raise ValueError(f"Unknown searcher type: {searcher_type}")