Update model_hf.py
Browse files- model_hf.py +6 -5
model_hf.py
CHANGED
@@ -9,6 +9,7 @@ from torch import Tensor
|
|
9 |
from .config_ssl import SSLConfig
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
from transformers import Wav2Vec2ForPreTraining
|
|
|
12 |
|
13 |
___author__ = "Hemlata Tak"
|
14 |
__email__ = "[email protected]"
|
@@ -23,11 +24,11 @@ class SSLModel(nn.Module):
|
|
23 |
super(SSLModel, self).__init__()
|
24 |
# eliminate fairseq dependency
|
25 |
# facebook/wav2vec2-xls-r-300m
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
self.model = model
|
31 |
self.model_device=device
|
32 |
self.out_dim = 1024
|
33 |
return
|
|
|
9 |
from .config_ssl import SSLConfig
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
from transformers import Wav2Vec2ForPreTraining
|
12 |
+
import fairseq
|
13 |
|
14 |
___author__ = "Hemlata Tak"
|
15 |
__email__ = "[email protected]"
|
|
|
24 |
super(SSLModel, self).__init__()
|
25 |
# eliminate fairseq dependency
|
26 |
# facebook/wav2vec2-xls-r-300m
|
27 |
+
repo_id = "ash56/ssl-aasist"
|
28 |
+
fname = "xlsr2_300m.pt"
|
29 |
+
cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
|
30 |
+
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
|
31 |
+
self.model = model[0]
|
32 |
self.model_device=device
|
33 |
self.out_dim = 1024
|
34 |
return
|