zhengr's picture
init
c02bdcd
raw
history blame
6.71 kB
import os
from pathlib import Path
import hashlib
import requests
from io import BytesIO
from typing import Dict, Tuple, Optional
from mmap import mmap, ACCESS_READ
from .log import logger
def sha256(fileno: int) -> str:
    """Return the hex SHA-256 digest of the file behind an open descriptor.

    The file is memory-mapped read-only, so it is hashed without first being
    read into a Python ``bytes`` object.

    Args:
        fileno: OS-level file descriptor opened for reading
            (e.g. ``f.fileno()`` of a file opened in ``"rb"`` mode).

    Returns:
        The lowercase hexadecimal SHA-256 digest of the file contents.
    """
    # mmap refuses to map a zero-length file (ValueError), so hash the empty
    # byte string directly in that case.
    if os.fstat(fileno).st_size == 0:
        return hashlib.sha256(b"").hexdigest()
    # The context manager releases the mapping even if hashing raises.
    with mmap(fileno, 0, access=ACCESS_READ) as data:
        return hashlib.sha256(data).hexdigest()
def check_model(
    dir_name: Path, model_name: str, hash: str, remove_incorrect=False
) -> bool:
    """Verify that ``dir_name / model_name`` exists and matches ``hash``.

    Args:
        dir_name: Directory containing the file.
        model_name: File name relative to ``dir_name``.
        hash: Expected hex SHA-256 digest, compared verbatim against the
            lowercase digest computed from the file. (The name shadows the
            builtin; it is kept so keyword callers keep working.)
        remove_incorrect: When True, a mismatching file is renamed to
            ``<file>.bak`` if no backup exists yet, or deleted otherwise. A
            leftover ``<file>.bak`` is deleted once the file verifies.

    Returns:
        True if the file exists and its digest equals ``hash``, else False.
    """
    target = dir_name / model_name
    relname = target.as_posix()
    logger.get_logger().debug(f"checking {relname}...")
    if not os.path.exists(target):
        logger.get_logger().info(f"{target} not exist.")
        return False
    # Hash inside the ``with`` but act on the result only after the file has
    # been closed: renaming/removing a file that is still open fails on Windows.
    with open(target, "rb") as f:
        digest = sha256(f.fileno())
    bakfile = f"{target}.bak"
    if digest != hash:
        logger.get_logger().warning(f"{target} sha256 hash mismatch.")
        logger.get_logger().info(f"expected: {hash}")
        logger.get_logger().info(f"real val: {digest}")
        if remove_incorrect:
            if not os.path.exists(bakfile):
                # Keep the first bad copy around as a backup.
                os.rename(str(target), bakfile)
            else:
                os.remove(str(target))
        return False
    # The file is good now, so any backup from an earlier failed check is stale.
    if remove_incorrect and os.path.exists(bakfile):
        os.remove(bakfile)
    return True
def check_folder(
    base_dir: Path,
    *inner_dirs: str,
    names: Tuple[str, ...],
    sha256_map: Dict[str, str],
    update=False,
) -> bool:
    """Verify every file in ``names`` inside ``base_dir / inner_dirs[0] / ...``.

    The expected digest of each file is looked up in ``sha256_map`` under
    ``"sha256_"``, followed by ``"<dir>_"`` for every inner directory,
    followed by the file name with dots replaced by underscores. For example,
    ``asset/gpt/config.json`` uses key ``sha256_asset_gpt_config_json``.

    Args:
        base_dir: Root directory.
        *inner_dirs: Sub-directory components below ``base_dir``.
        names: File names to verify in that directory, checked in order.
        sha256_map: Mapping from digest key to expected hex SHA-256 digest.
        update: Forwarded to ``check_model`` as ``remove_incorrect``.

    Returns:
        True if every file verified; False at the first failure.

    Raises:
        KeyError: If ``sha256_map`` has no entry for one of the files.
    """
    current_dir = base_dir.joinpath(*inner_dirs)
    key_prefix = "sha256_" + "".join(f"{d}_" for d in inner_dirs)
    for model in names:
        expected = sha256_map[key_prefix + model.replace(".", "_")]
        if not check_model(current_dir, model, expected, update):
            return False
    return True
def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -> bool:
    """Verify the full set of model asset files below ``base_dir``.

    Folders are checked in a fixed order, and checking stops at the first
    folder that contains a missing or mismatching file.

    Args:
        base_dir: Root directory that contains the ``asset`` folder.
        sha256_map: Expected digests, keyed as described by ``check_folder``.
        update: Passed through to ``check_folder``.

    Returns:
        True only if every listed asset verified.
    """
    logger.get_logger().info("checking assets...")
    # Each entry: (sub-directories below base_dir, files expected there).
    asset_layout = (
        (
            ("asset",),
            ("Decoder.pt", "DVAE_full.pt", "Embed.safetensors", "Vocos.pt"),
        ),
        (
            ("asset", "gpt"),
            ("config.json", "model.safetensors"),
        ),
        (
            ("asset", "tokenizer"),
            (
                "special_tokens_map.json",
                "tokenizer_config.json",
                "tokenizer.json",
            ),
        ),
    )
    for subdirs, filenames in asset_layout:
        folder_ok = check_folder(
            base_dir,
            *subdirs,
            names=filenames,
            sha256_map=sha256_map,
            update=update,
        )
        if not folder_ok:
            return False
    logger.get_logger().info("all assets are already latest.")
    return True
def download_and_extract_tar_gz(
    url: str, folder: str, headers: Optional[Dict[str, str]] = None
):
    """Download a ``.tar.gz`` archive into memory and extract it into ``folder``.

    Args:
        url: Archive URL.
        folder: Destination directory for the extracted members.
        headers: Optional extra HTTP request headers.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection/read failures or timeouts.
        tarfile.TarError: If the payload is not a valid gzip-compressed tar.
    """
    import tarfile

    logger.get_logger().info(f"downloading {url}")
    with requests.get(url, headers=headers, stream=True, timeout=(10, 3)) as response:
        # Fail on an HTTP error instead of handing an error page to tarfile.
        response.raise_for_status()
        payload = response.content
    logger.get_logger().info("downloaded.")
    with BytesIO(payload) as buf, tarfile.open(fileobj=buf, mode="r:gz") as tar:
        # The archive comes from the network. Where the interpreter supports
        # extraction filters, reject members that would escape ``folder``
        # (absolute paths, ``..`` components, unsafe links, device files).
        if hasattr(tarfile, "data_filter"):
            tar.extractall(folder, filter="data")
        else:
            tar.extractall(folder)
    logger.get_logger().info(f"extracted into {folder}")
def download_and_extract_zip(
    url: str, folder: str, headers: Optional[Dict[str, str]] = None
):
    """Download a ``.zip`` archive into memory and extract it into ``folder``.

    Args:
        url: Archive URL.
        folder: Destination directory for the extracted members.
        headers: Optional extra HTTP request headers.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection/read failures or timeouts.
        zipfile.BadZipFile: If the payload is not a valid zip archive.
    """
    import zipfile

    logger.get_logger().info(f"downloading {url}")
    with requests.get(url, headers=headers, stream=True, timeout=(10, 3)) as response:
        # Fail on an HTTP error instead of handing an error page to zipfile.
        response.raise_for_status()
        payload = response.content
    logger.get_logger().info("downloaded.")
    with BytesIO(payload) as buf, zipfile.ZipFile(buf) as zip_ref:
        zip_ref.extractall(folder)
    logger.get_logger().info(f"extracted into {folder}")
def download_dns_yaml(url: str, folder: str, headers: Dict[str, str]):
    """Download ``url`` and save the body as ``<folder>/dns.yaml``.

    Args:
        url: URL of the DNS configuration file.
        folder: Existing directory to write ``dns.yaml`` into.
        headers: HTTP request headers to send.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. In
            that case nothing is written to disk.
        requests.RequestException: On connection/read failures or timeouts.
    """
    logger.get_logger().info(f"downloading {url}")
    # NOTE(review): the connect timeout here is 100 s, while the archive
    # downloaders use 10 s. Confirm this difference is intentional.
    with requests.get(url, headers=headers, stream=True, timeout=(100, 3)) as response:
        # Check the status before opening the file, so that an error page is
        # never saved as dns.yaml.
        response.raise_for_status()
        with open(os.path.join(folder, "dns.yaml"), "wb") as out_file:
            out_file.write(response.content)
    logger.get_logger().info(f"downloaded into {folder}")
def _download_rvcmd(
    base_url: str,
    version: str,
    system_type: str,
    architecture: str,
    is_win: bool,
    tmpdir: str,
) -> str:
    """Fetch and unpack the rvcmd release for this platform into ``tmpdir``; return the executable path."""
    suffix = "zip" if is_win else "tar.gz"
    rvcmd_url = base_url + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}"
    cmdfile = os.path.join(tmpdir, "rvcmd")
    if is_win:
        download_and_extract_zip(rvcmd_url, tmpdir)
        cmdfile += ".exe"
    else:
        download_and_extract_tar_gz(rvcmd_url, tmpdir)
        # Ensure the extracted binary is executable on POSIX systems.
        os.chmod(cmdfile, 0o755)
    return cmdfile


def download_all_assets(tmpdir: str, version="0.2.8"):
    """Download the rvcmd tool into ``tmpdir`` and run it to fetch ``assets/chtts``.

    The GitHub release is tried first. If any step of that path raises, the
    error is logged and the download is retried from the gitea mirror. The
    retry also fetches a ``dns.yaml`` and passes it to rvcmd via ``-dns``.

    Args:
        tmpdir: Scratch directory that receives the rvcmd binary (and
            ``dns.yaml`` on the fallback path).
        version: rvcmd release version, without the leading ``v``.

    Raises:
        SystemExit: With status 1 if the CPU architecture is not supported.
    """
    import platform
    import subprocess
    import sys

    # platform.machine() spellings -> architecture names used in release files.
    archs = {
        "aarch64": "arm64",
        "armv8l": "arm64",
        "arm64": "arm64",
        "x86": "386",
        "i386": "386",
        "i686": "386",
        "386": "386",
        "x86_64": "amd64",
        "x64": "amd64",
        "amd64": "amd64",
    }
    system_type = platform.system().lower()
    machine = platform.machine().lower()
    is_win = system_type == "windows"
    architecture = archs.get(machine)
    if not architecture:
        # Report the raw machine name; the lookup result is None at this point.
        logger.get_logger().error(f"architecture {machine} is not supported")
        sys.exit(1)
    try:
        BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/"
        cmdfile = _download_rvcmd(
            BASE_URL, version, system_type, architecture, is_win, tmpdir
        )
        subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"])
    except Exception as e:
        # Best-effort fallback: keep going, but do not hide why the first attempt failed.
        logger.get_logger().warning(
            f"github download path failed ({e!r}), retrying via gitea mirror"
        )
        BASE_URL = (
            "https://gitea.seku.su/fumiama/RVC-Models-Downloader/releases/download/"
        )
        download_dns_yaml(
            "https://gitea.seku.su/fumiama/RVC-Models-Downloader/raw/branch/main/dns.yaml",
            tmpdir,
            headers={
                "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36 Edg/128.0.0.0"
            },
        )
        cmdfile = _download_rvcmd(
            BASE_URL, version, system_type, architecture, is_win, tmpdir
        )
        subprocess.run(
            [
                cmdfile,
                "-notui",
                "-w",
                "0",
                "-dns",
                os.path.join(tmpdir, "dns.yaml"),
                "assets/chtts",
            ]
        )