File size: 1,010 Bytes
e9812a3
 
 
4ac7ffc
e9812a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import logging
import tensorflow as tf
from functools import lru_cache
from uetasr.searchers import GreedyRNNT, BeamRNNT


@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Build (and memoize) an RNN-T searcher for the given configuration.

    Results are cached with ``functools.lru_cache`` (at most 5 entries), so
    repeated calls with identical arguments return the same searcher object.
    Every argument must therefore be hashable; the model/layer arguments are
    presumably hashed by object identity -- TODO confirm.

    Args:
        searcher_type: ``"greedy_search"`` to build a ``GreedyRNNT`` or
            ``"beam_search"`` to build a ``BeamRNNT``.
        decoder: Passed through to the searcher as ``decoder``.
        jointer: Passed through to the searcher as ``jointer``.
        text_decoder: Passed through to the searcher as ``text_decoder``.
        beam_size: Beam width; only used when ``searcher_type`` is
            ``"beam_search"``.
        max_symbols_per_step: Forwarded to either searcher type.

    Returns:
        A ``GreedyRNNT`` or ``BeamRNNT`` instance constructed with
        ``return_scores=False``.

    Raises:
        ValueError: If ``searcher_type`` is not a supported value. Exceptions
            are not cached by ``lru_cache``, so a later valid call still works.
    """
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }
    if searcher_type == "greedy_search":
        return GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )
    if searcher_type == "beam_search":
        return BeamRNNT(
            max_symbols_per_step=max_symbols_per_step,
            beam=beam_size,
            # alpha is fixed at 0.0 here and not configurable via this factory.
            alpha=0.0,
            **common_kwargs,
        )

    # There is no sensible fallback searcher; fail loudly with a clear message
    # rather than falling through with nothing to return.
    message = f"Unknown searcher type: {searcher_type}"
    logging.error(message)
    raise ValueError(message)