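# torch.hub-style entry points for the data2vec upstream: each function below
# returns an UpstreamExpert built from a local or downloadable checkpoint.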
import os
import torch
from s3prl.util.download import _urls_to_filepaths
from .expert import UpstreamExpert as _UpstreamExpert


def data2vec_custom(ckpt: str, refresh: bool = False, **kwargs):
    # If given a URL, download the checkpoint (or reuse the cached copy) and
    # point ckpt at the resulting local path before building the expert.
    if ckpt.startswith("http"):
        ckpt = _urls_to_filepaths(ckpt, refresh=refresh)

    return _UpstreamExpert(ckpt, **kwargs)


def data2vec_local(*args, **kwargs):
    return data2vec_custom(*args, **kwargs)


def data2vec_url(*args, **kwargs):
    return data2vec_custom(*args, **kwargs)


def data2vec(refresh=False, *args, **kwargs):
    """
    The default model - Base
        refresh (bool): whether to re-download the ckpt/config even if they already exist
    """
    return data2vec_base_960(refresh=refresh, *args, **kwargs)


def data2vec_base_960(refresh=False, *args, **kwargs):
    """
    The Base model trained on LibriSpeech 960 hours of data
        refresh (bool): whether to re-download the ckpt/config even if they already exist
    """
    kwargs[
        "ckpt"
    ] = "https://huggingface.co/s3prl/converted_ckpts/resolve/main/audio_base_ls.pt"
    return data2vec_custom(refresh=refresh, *args, **kwargs)


def data2vec_large_ll60k(refresh=False, *args, **kwargs):
    """
    The Large model trained on Libri-light 60k hours of data
        refresh (bool): whether to re-download the ckpt/config even if they already exist
    """
    kwargs[
        "ckpt"
    ] = "https://huggingface.co/s3prl/converted_ckpts/resolve/main/vox_pretrained.pt"
    return data2vec_custom(refresh=refresh, *args, **kwargs)
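

# Example usage (a sketch, not part of this module): these functions follow
# s3prl's torch.hub upstream convention, so loading the default Base model
# would presumably look like the snippet below once this upstream is registered.
#
#   import torch
#   upstream = torch.hub.load("s3prl/s3prl", "data2vec")   # calls data2vec() above
#   wavs = [torch.randn(16000) for _ in range(2)]          # list of 16 kHz waveforms
#   results = upstream(wavs)                                # dict of upstream outputs
#   features = results["hidden_states"]                     # layer-wise representations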