roberta-base-ZipNN-Compressed / zipnn_decompress_path.py
royleibov's picture
add znn
62be76c
import os
import sys
import argparse
import subprocess
from pathlib import Path
from concurrent.futures import (
ProcessPoolExecutor,
as_completed,
)
from zipnn_decompress_file import (
decompress_file,
)
# Append the parent directory of this file to the module search path.
# NOTE(review): this runs after the `zipnn_decompress_file` import above, so
# it cannot affect how that import is resolved - confirm the intended order.
sys.path.append(
os.path.abspath(
os.path.join(
os.path.dirname(__file__),
"..",
)
)
)
# ANSI escape sequences used to colour console messages; RESET switches the
# terminal back to its default attributes.
RED = "\033[91m"
YELLOW = "\033[93m"
GREEN = "\033[92m"
RESET = "\033[0m"
def check_and_install_zipnn():
    """Ensure the ``zipnn`` package is importable, installing it on demand.

    The import is attempted first. Only when it raises ``ImportError`` is
    ``pip install zipnn`` run through the current interpreter, after which
    the import is retried so that a failed installation surfaces right away.

    Raises:
        subprocess.CalledProcessError: If the pip command exits non-zero.
        ImportError: If ``zipnn`` still cannot be imported after installing.
    """
    try:
        import zipnn  # noqa: F401  (availability probe only)
    except ImportError:
        print("zipnn not found. Installing...")
        pip_command = [sys.executable, "-m", "pip", "install", "zipnn"]
        subprocess.check_call(pip_command)
        import zipnn  # noqa: F401  (retry after installation)
def replace_in_file(file_path, old: str, new: str) -> None:
    """Replace every occurrence of `old` with `new` in a text file, in place.

    The file is read and written back explicitly as UTF-8 so the result does
    not depend on the platform's default locale encoding.

    Args:
        file_path: Path of the text file to rewrite.
        old: Substring to look for.
        new: Text that replaces each occurrence of `old`.

    Raises:
        OSError: If the file cannot be opened for reading or writing.
        UnicodeDecodeError: If the file content is not valid UTF-8.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        file_data = file.read()

    file_data = file_data.replace(old, new)

    with open(file_path, "w", encoding="utf-8") as file:
        file.write(file_data)
def _resolve_hf_snapshot_path(model, branch):
    """Return the cached snapshot directory of `model` at `branch`, or None if the repo is not cached."""
    try:
        from huggingface_hub import scan_cache_dir
    except ImportError as exc:
        raise ImportError(
            "huggingface_hub not found. Please pip install huggingface_hub."
        ) from exc

    cache = scan_cache_dir()
    repo = next((repo for repo in cache.repos if repo.repo_id == model), None)
    if repo is None:
        return None

    print(f"Found repo {model} in cache")

    # The ref file of the branch holds the revision that names the snapshot
    # directory; strip it so stray whitespace cannot end up in the path.
    try:
        with open(os.path.join(repo.repo_path, "refs", branch), "r") as ref:
            revision = ref.read().strip()
    except FileNotFoundError as exc:
        raise FileNotFoundError(
            f"Branch {branch} not found in repo {model}"
        ) from exc

    return os.path.join(repo.repo_path, "snapshots", revision)


def _collect_znn_files(path, force):
    """Return full paths of the `.znn` files directly inside `path` that should be decompressed, prompting before overwrites."""
    selected = []
    ask_overwrite_all = True

    for file_name in os.listdir(path):
        if not file_name.endswith(".znn"):
            continue

        decompressed_name = file_name[:-4]
        # Look for the existing output inside `path`, next to the compressed
        # file, rather than relative to the current working directory.
        if not force and os.path.exists(os.path.join(path, decompressed_name)):
            # The first conflict offers a blanket "overwrite all" choice.
            if ask_overwrite_all:
                ask_overwrite_all = False
                user_input = (
                    input(
                        "Decompressed files already exists; Would you like to overwrite them all (y/n)? "
                    )
                    .strip()
                    .lower()
                )
                if user_input in ("y", "yes"):
                    print("Overwriting all decompressed files.")
                    force = True
                else:
                    print("No forced overwriting.")

            # Without a blanket approval, every conflict is confirmed one by one.
            if not force:
                user_input = (
                    input(f"{decompressed_name} already exists; overwrite (y/n)? ")
                    .strip()
                    .lower()
                )
                if user_input not in ("y", "yes"):
                    print(f"Skipping {file_name}...")
                    continue

        selected.append(os.path.join(path, file_name))

    return selected


def _fix_hf_index_json(path, file_list):
    """Strip the `.znn` suffix from file names referenced by the model's weights index in `path`."""
    try:
        from transformers.utils import (
            SAFE_WEIGHTS_INDEX_NAME,
            WEIGHTS_INDEX_NAME,
        )
    except ImportError as exc:
        raise ImportError(
            "Transformers not found. Please pip install transformers."
        ) from exc

    # Name component right before ".znn" in the first compressed file.
    suffix = os.path.basename(file_list[0]).split(".")[-2]

    # The safetensors index takes precedence; only the first index found is edited.
    for index_name in (SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME):
        index_path = os.path.join(path, index_name)
        if not os.path.exists(index_path):
            continue
        print(f"{YELLOW}Fixing Hugging Face model json...{RESET}")
        # The index entry is read as a symlink; the file it points to is the
        # one that gets rewritten.
        blob_name = os.path.join(path, os.readlink(index_path))
        replace_in_file(
            file_path=blob_name,
            old=f"{suffix}.znn",
            new=f"{suffix}",
        )
        break


def decompress_znn_files(
    path=".",
    delete=False,
    force=False,
    max_processes=1,
    hf_cache=False,
    model="",
    branch="main",
):
    """Decompress every `.znn` file found directly inside a directory.

    Only the top level of the directory is scanned (no recursion). When a
    decompressed counterpart already exists and `force` is false, the user is
    asked interactively whether to overwrite. The selected files are then
    handed to `decompress_file` in a pool of worker processes.

    Args:
        path: Directory to scan. Replaced by the cached snapshot directory
            when `model` is given and found in the Hugging Face cache.
        delete: Forwarded unchanged to `decompress_file`.
        force: Overwrite existing decompressed files without asking.
        max_processes: Maximum number of worker processes.
        hf_cache: Forwarded to `decompress_file`; when set, the model's
            weights index JSON is also rewritten to drop the `.znn` suffix.
        model: Hugging Face repo id to look up in the local cache. Requires
            `hf_cache`.
        branch: Branch whose ref selects the snapshot when `model` is used.

    Raises:
        ValueError: If `model` is given without `hf_cache`.
        ImportError: If `zipnn`, or an optional package needed for the
            requested mode (`huggingface_hub`, `transformers`), is missing.
        FileNotFoundError: If `branch` has no ref in the cached repo.
    """
    # Not used directly here; importing it raises ImportError up front when
    # zipnn is unavailable.
    import zipnn  # noqa: F401

    if model:
        if not hf_cache:
            raise ValueError(
                "Must specify --hf_cache when using --model"
            )
        snapshot_path = _resolve_hf_snapshot_path(model, branch)
        if snapshot_path is None:
            # Keep falling back to `path`, but no longer silently.
            print(
                f"{YELLOW}Repo {model} not found in the Hugging Face cache; "
                f"using path {path} instead.{RESET}"
            )
        else:
            path = snapshot_path

    file_list = _collect_znn_files(path, force)

    if file_list and hf_cache:
        _fix_hf_index_json(path, file_list)

    # The executor itself limits concurrency to `max_processes`, so all files
    # can be submitted at once.
    # NOTE(review): the third positional argument (True) is presumably
    # decompress_file's `force` flag, since overwrite decisions were already
    # taken above - confirm against zipnn_decompress_file.
    with ProcessPoolExecutor(max_workers=max_processes) as executor:
        future_to_file = {
            executor.submit(
                decompress_file,
                file,
                delete,
                True,
                hf_cache,
            ): file
            for file in file_list
        }
        for future in as_completed(future_to_file):
            file = future_to_file[future]
            try:
                future.result()
            except Exception as exc:
                # Report the failure and keep processing the remaining files.
                print(
                    f"{RED}File {file} generated an exception: {exc}{RESET}"
                )

    print(f"{GREEN}All files decompressed{RESET}")
if __name__ == "__main__":
    # Make sure zipnn can be imported before any decompression work starts.
    check_and_install_zipnn()

    parser = argparse.ArgumentParser(
        description="Decompresses all .znn files."
    )
    parser.add_argument(
        "--path",
        type=str,
        help="Path to folder of files to decompress. If left empty, checks current folder.",
    )
    parser.add_argument(
        "--delete",
        action="store_true",
        help="A flag that triggers deletion of a single compressed file instead of decompression",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="A flag that forces overwriting when decompressing.",
    )
    parser.add_argument(
        "--max_processes",
        type=int,
        help="The amount of maximum processes.",
    )
    parser.add_argument(
        "--hf_cache",
        action="store_true",
        help="A flag that indicates if the file is in the Hugging Face cache. Must either specify --model or --path to the model's snapshot cache.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Only when using --hf_cache, specify the model name or path. E.g. 'ibm-granite/granite-7b-instruct'",
    )
    parser.add_argument(
        "--model_branch",
        type=str,
        default="main",
        help="Only when using --model, specify the model branch. Default is 'main'",
    )
    args = parser.parse_args()

    # Forward only the options that were actually set, so the defaults of
    # decompress_znn_files apply to everything else.
    optional_kwargs = {}
    if args.path is not None:
        optional_kwargs["path"] = args.path
    if args.delete:
        optional_kwargs["delete"] = args.delete
    if args.force:
        optional_kwargs["force"] = args.force
    if args.max_processes:
        optional_kwargs["max_processes"] = args.max_processes
    if args.hf_cache:
        optional_kwargs["hf_cache"] = args.hf_cache
    if args.model:
        optional_kwargs["model"] = args.model
    if args.model_branch:
        # The CLI flag is --model_branch; the function parameter is `branch`.
        optional_kwargs["branch"] = args.model_branch

    decompress_znn_files(**optional_kwargs)