# OSUM / wenet / cli / hub.py
# Last commit 568e264 by tomxxie: "适配zeroGPU" (adapt for ZeroGPU)
# (raw | history | blame) - 3.67 kB
# Copyright (c) 2022 Mddct([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import sys
import tarfile
from pathlib import Path
from urllib.request import urlretrieve

import requests
import tqdm
def download(url: str, dest: str, only_child=True):
    """Download a tar archive from ``url`` into ``dest`` and unpack it there.

    The archive is saved as ``dest/<name>``, where ``<name>`` is the last
    path component of ``url`` with any query string removed. A tqdm bar
    reports download progress. The archive file is left in ``dest``.

    Args:
        url: Address of a tar archive (typically ``*.tar.gz``).
        dest: Existing directory that receives the archive and its contents.
        only_child: If True, only members whose name contains ``/`` (i.e.
            members inside a directory of the archive) are extracted. Each
            one is written flat into ``dest`` under its base name. Top-level
            members and members without file data (e.g. directories) are
            skipped. If False, the archive is extracted into ``dest`` with
            its layout preserved.

    Raises:
        AssertionError: If ``dest`` does not exist.
    """
    # An explicit raise (unlike ``assert``) still fires under ``python -O``;
    # AssertionError is kept so callers see the same exception type as before.
    if not os.path.exists(dest):
        raise AssertionError('Destination does not exist: {}'.format(dest))
    print('Downloading {} to {}'.format(url, dest))

    def progress_hook(t):
        """Return a ``urlretrieve`` reporthook that advances tqdm bar ``t``."""
        last_b = [0]  # blocks already reported; a list so the closure can mutate it

        def update_to(b=1, bsize=1, tsize=None):
            # b: blocks transferred so far; bsize: block size in bytes;
            # tsize: total size, -1 (or None) when the server did not send it.
            if tsize not in (None, -1):
                t.total = tsize
            displayed = t.update((b - last_b[0]) * bsize)
            last_b[0] = b
            return displayed

        return update_to

    # e.g. ".../paraformer.tar.gz?token=..." -> "paraformer.tar.gz"
    archive_name = url.split('?')[0].split('/')[-1]
    tar_path = os.path.join(dest, archive_name)
    with tqdm.tqdm(unit='B',
                   unit_scale=True,
                   unit_divisor=1024,
                   miniters=1,
                   desc=archive_name) as t:
        urlretrieve(url,
                    filename=tar_path,
                    reporthook=progress_hook(t),
                    data=None)
        # Snap the bar's total to the number of bytes actually received.
        t.total = t.n

    with tarfile.open(tar_path) as f:
        if not only_child:
            if hasattr(tarfile, 'data_filter'):
                # The archive comes from the network: the 'data' filter
                # (PEP 706) rejects absolute paths and members that would
                # land outside ``dest``.
                f.extractall(dest, filter='data')
            else:
                f.extractall(dest)
        else:
            for tarinfo in f:
                if "/" not in tarinfo.name:
                    continue
                fileobj = f.extractfile(tarinfo)
                if fileobj is None:
                    # Directories and other members without file data.
                    continue
                # basename() flattens the layout and also keeps "../"
                # components from escaping ``dest``.
                out_path = os.path.join(dest, os.path.basename(tarinfo.name))
                with fileobj, open(out_path, "wb") as writer:
                    # Stream in chunks instead of loading the whole member.
                    shutil.copyfileobj(fileobj, writer)
class Hub(object):
    """Hub for wenet pretrained runtime models.

    Maps a model key to an archive name in the ``wenet/wenet_pretrained_models``
    dataset on ModelScope. The unpacked model is cached under
    ``~/.wenet/<key>``.
    """
    # TODO(Mddct): make assets class to support other language
    Assets = {
        # wenetspeech
        "chinese": "wenetspeech_u2pp_conformer_libtorch.tar.gz",
        # gigaspeech
        "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz",
        # paraformer
        "paraformer": "paraformer.tar.gz"
    }

    # Listing endpoint. The code below reads a JSON object whose "Data" list
    # holds entries with "Key" (archive name) and "Url" (download URL).
    ModelListUrl = "https://modelscope.cn/api/v1/datasets/wenet/wenet_pretrained_models/oss/tree"  # noqa
    # Seconds to wait on the listing request; requests never times out
    # unless a timeout is passed explicitly.
    RequestTimeout = 30

    def __init__(self) -> None:
        pass

    @staticmethod
    def get_model_by_lang(lang: str) -> str:
        """Return the local model directory for ``lang``, downloading if needed.

        If ``~/.wenet/<lang>`` already contains both ``final.zip`` and
        ``units.txt``, it is returned without any network access. Otherwise
        the archive's URL is looked up on ModelScope, and the archive is
        downloaded and extracted into that directory.

        Args:
            lang: A key of ``Hub.Assets`` ("chinese", "english" or
                "paraformer").

        Returns:
            The path ``~/.wenet/<lang>``.

        Raises:
            SystemExit: With code 1, if ``lang`` is not a key of
                ``Hub.Assets`` or its archive is missing from the listing.
            requests.RequestException: If the listing request fails, times
                out, or returns an HTTP error status.
        """
        if lang not in Hub.Assets:
            print('ERROR: Unsupported language {} !!!'.format(lang))
            sys.exit(1)
        # NOTE(Mddct): model_dir structure
        # Path.home()/.wenet
        #   - <lang>        (a key of Hub.Assets, e.g. "chinese")
        #       - units.txt
        #       - final.zip
        model = Hub.Assets[lang]
        model_dir = os.path.join(Path.home(), ".wenet", lang)
        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(model_dir, exist_ok=True)
        # TODO(Mddct): model metadata
        if {"final.zip", "units.txt"}.issubset(os.listdir(model_dir)):
            return model_dir
        # Not cached yet: look up the archive's download URL and fetch it.
        response = requests.get(Hub.ModelListUrl, timeout=Hub.RequestTimeout)
        response.raise_for_status()
        model_info = next((data for data in response.json()["Data"]
                           if data["Key"] == model), None)
        if model_info is None:
            print('ERROR: Model {} not found on ModelScope !!!'.format(model))
            sys.exit(1)
        download(model_info['Url'], model_dir, only_child=True)
        return model_dir