File size: 609 Bytes
568e264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from wenet.ssl.bestrq.bestrq_model import BestRQModel
from wenet.ssl.wav2vec2.wav2vec2_model import Wav2vec2Model
from wenet.ssl.w2vbert.w2vbert_model import W2VBERTModel

# Registry of self-supervised model classes. The key is the string that
# ``init_model`` reads from ``configs['model']``; the value is the class it
# instantiates.
WENET_SSL_MODEL_CLASS = dict(
    w2vbert_model=W2VBERTModel,
    wav2vec_model=Wav2vec2Model,
    bestrq_model=BestRQModel,
)


def init_model(configs, encoder):
    """Build the self-supervised model selected by ``configs['model']``.

    Args:
        configs: Parsed training configuration mapping. ``configs['model']``
            must name an entry of ``WENET_SSL_MODEL_CLASS``. The optional
            ``configs['model_conf']`` mapping is forwarded as keyword
            arguments to that class's constructor.
        encoder: Encoder object passed to the model constructor as
            ``encoder=``.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: If ``configs`` has no ``'model'`` entry.
        ValueError: If ``configs['model']`` is not a registered SSL model.
    """
    # Explicit raises instead of ``assert``: asserts are stripped under
    # ``python -O``, and a bare assert gives no hint about valid choices.
    if 'model' not in configs:
        raise KeyError("configs must contain a 'model' entry naming one of: "
                       f"{sorted(WENET_SSL_MODEL_CLASS)}")
    model_type = configs['model']
    model_class = WENET_SSL_MODEL_CLASS.get(model_type)
    if model_class is None:
        raise ValueError(f"Unknown SSL model type {model_type!r}; expected "
                         f"one of {sorted(WENET_SSL_MODEL_CLASS)}")
    # A missing section, or an empty one (``model_conf:`` with no value
    # parses as None), means "use the constructor's defaults".
    model_conf = configs.get('model_conf') or {}
    return model_class(encoder=encoder, **model_conf)