|
import os |
|
import subprocess |
|
import sys |
|
import argparse |
|
from pathlib import Path |
|
from concurrent.futures import ( |
|
ProcessPoolExecutor, |
|
as_completed, |
|
) |
|
from zipnn_compress_file import compress_file |
|
|
|
# Make the directory one level above this script importable.  append() puts
# it at the end of sys.path, so it is searched after every existing entry.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
)

# Binary byte-size multipliers (used when parsing chunk sizes like "1MB").
KB, MB, GB = 1 << 10, 1 << 20, 1 << 30

# ANSI escape sequences used to colour terminal messages.
RED, YELLOW, GREEN, RESET = "\033[91m", "\033[93m", "\033[92m", "\033[0m"
|
|
|
|
|
def check_and_install_zipnn():
    """Make sure the ``zipnn`` package is importable, installing it if needed.

    If ``import zipnn`` fails, runs ``<this python> -m pip install zipnn
    --upgrade`` and imports it again.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
        ImportError: if zipnn still cannot be imported after installing it.
    """
    try:
        import zipnn  # noqa: F401
    except ImportError:
        print("zipnn not found. Installing...")
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "zipnn",
                "--upgrade",
            ]
        )
        # The import system caches directory listings; a package that appears
        # after the interpreter started may not be found until those caches
        # are invalidated.
        import importlib

        importlib.invalidate_caches()
        import zipnn  # noqa: F401
|
|
|
|
|
def parse_streaming_chunk_size(
    streaming_chunk_size,
):
    """Convert a streaming chunk size specification into a number of bytes.

    Accepted forms (case-insensitive, surrounding whitespace ignored):
      * an int, or a string of decimal digits -> that many bytes;
      * digits followed by ``k``/``m``/``g`` with an optional trailing ``b``,
        e.g. ``"512KB"``, ``"1M"``, ``"2gb"`` -> value * KB/MB/GB;
      * digits followed by a bare ``b``, e.g. ``"100B"`` -> that many bytes.

    Args:
        streaming_chunk_size: int or str in one of the forms above.

    Returns:
        The chunk size in bytes, as an int.

    Raises:
        ValueError: if the number part is missing/not an integer, or the unit
            is not k, m or g.
    """
    text = str(streaming_chunk_size).strip()
    # isdecimal (not isdigit): isdigit accepts characters such as "²" that
    # int() then rejects.
    if text.isdecimal():
        return int(text)

    spec = text.lower()
    # The trailing "b" is optional, so "1m" and "1mb" are equivalent.  The old
    # code always sliced off two characters and therefore rejected "1M".
    if spec.endswith("b"):
        spec = spec[:-1]
    if spec.isdecimal():
        return int(spec)

    digits = spec[:-1].strip()
    size_unit = spec[-1:]
    if not digits.isdecimal():
        raise ValueError(
            f"Invalid streaming chunk size: {streaming_chunk_size!r}. "
            "Use an integer number of bytes or an integer followed by KB, MB or GB."
        )

    multipliers = {"k": KB, "m": MB, "g": GB}
    if size_unit not in multipliers:
        raise ValueError(
            f"Invalid size unit: {size_unit}. Use 'k', 'm', or 'g'."
        )
    return int(digits) * multipliers[size_unit]
|
|
|
def replace_in_file(file_path, old: str, new: str) -> None:
    """Replace all occurrences of `old` with `new` in the file at `file_path`, in place.

    The file is read and written as UTF-8, so the result does not depend on the
    platform's locale encoding.  ``newline=""`` disables newline translation,
    so existing line endings (LF or CRLF) are written back unchanged.  The file
    is not rewritten when nothing matched.

    Raises:
        ValueError: if `old` is empty (str.replace would insert `new` between
            every character and corrupt the file).
    """
    if not old:
        raise ValueError("`old` must be a non-empty string")

    with open(file_path, "r", encoding="utf-8", newline="") as file:
        file_data = file.read()

    updated = file_data.replace(old, new)
    if updated == file_data:
        return

    with open(file_path, "w", encoding="utf-8", newline="") as file:
        file.write(updated)
|
|
|
def _resolve_hf_snapshot_path(model, branch):
    """Return the local Hugging Face cache snapshot directory of `model` at `branch`.

    Raises:
        ImportError: if huggingface_hub is not installed.
        FileNotFoundError: if `model` is not in the local cache, or the cache
            has no ref file for `branch`.
    """
    try:
        from huggingface_hub import scan_cache_dir
    except ImportError as err:
        raise ImportError(
            "huggingface_hub not found. Please pip install huggingface_hub."
        ) from err

    cache = scan_cache_dir()
    repo = next((repo for repo in cache.repos if repo.repo_id == model), None)
    if repo is None:
        # Previously this fell through with `path` unchanged (default ".") and
        # compressed whatever matched in the current directory.
        raise FileNotFoundError(f"Repo {model} not found in the Hugging Face cache")
    print(f"Found repo {model} in cache")

    # refs/<branch> holds the commit hash that names the snapshot directory.
    try:
        with open(os.path.join(repo.repo_path, "refs", branch), "r") as ref:
            commit_hash = ref.read().strip()
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Branch {branch} not found in repo {model}"
        ) from err
    return os.path.join(repo.repo_path, "snapshots", commit_hash)


def _ask_yes_no(prompt):
    """Prompt on stdin; return True only for a 'y' or 'yes' answer (case-insensitive)."""
    return input(prompt).strip().lower() in ("y", "yes")


def _collect_files_to_compress(path, suffix, recursive, force):
    """Find files under `path` whose names end with `suffix`.

    When a ``<file>.znn`` output already exists and `force` is false, the user
    is first asked once whether to overwrite all such outputs; if not, they
    are asked per file, and declined files are skipped.

    Returns:
        (matched, selected): how many file names ended with `suffix`, and the
        full paths of the files that should actually be compressed.
    """
    if recursive:
        directories_to_search = os.walk(path)
    else:
        # os.listdir also lists sub-directories; keep only files so that a
        # directory whose name ends with `suffix` is not sent to the compressor.
        names = [
            name
            for name in os.listdir(path)
            if os.path.isfile(os.path.join(path, name))
        ]
        directories_to_search = [(path, [], names)]

    ask_overwrite_all = True
    matched = 0
    selected = []
    for root, _, files in directories_to_search:
        for file_name in files:
            if not file_name.endswith(suffix):
                continue
            matched += 1
            full_path = os.path.join(root, file_name)
            # Check for the output next to its source file rather than in the
            # current working directory.  NOTE(review): assumes compress_file
            # writes "<source>.znn" beside the source, as the .znn rename of
            # the HF index implies — confirm.
            compressed_path = full_path + ".znn"
            if not force and os.path.exists(compressed_path):
                if ask_overwrite_all:
                    ask_overwrite_all = False
                    if _ask_yes_no(
                        "Compressed files already exists; Would you like to overwrite them all (y/n)? "
                    ):
                        print("Overwriting all compressed files.")
                        force = True
                    else:
                        print("No forced overwriting.")
                if not force and not _ask_yes_no(
                    f"{compressed_path} already exists; overwrite (y/n)? "
                ):
                    print(f"Skipping {file_name}...")
                    continue
            selected.append(full_path)
    return matched, selected


def _point_hf_index_to_znn(path, suffix):
    """Rewrite the model's weight index json in `path` so shard names end with `suffix` + ".znn".

    Only the first existing index is edited, checking the safetensors index
    name before the PyTorch one.

    Raises:
        ImportError: if transformers is not installed.
    """
    try:
        from transformers.utils import (
            SAFE_WEIGHTS_INDEX_NAME,
            WEIGHTS_INDEX_NAME,
        )
    except ImportError as err:
        raise ImportError(
            "Transformers not found. Please pip install transformers."
        ) from err

    for index_name in (SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME):
        index_path = os.path.join(path, index_name)
        if not os.path.exists(index_path):
            continue
        print(f"{YELLOW}Fixing Hugging Face model json...{RESET}")
        # In the HF cache the index is a symlink to a blob; edit the target.
        # realpath also handles a plain file, where os.readlink would raise.
        blob_path = os.path.realpath(index_path)
        # Undo an earlier run's rename first so re-running does not produce
        # "<suffix>.znn.znn".
        replace_in_file(file_path=blob_path, old=f"{suffix}.znn", new=suffix)
        replace_in_file(file_path=blob_path, old=suffix, new=f"{suffix}.znn")
        break


def _compress_in_pool(
    file_list, dtype, streaming_chunk_size, delete, hf_cache, max_processes
):
    """Compress every file in `file_list` using a process pool; return the number of failures.

    A failure is reported in red and does not stop the remaining files.
    """
    if not file_list:
        return 0
    failures = 0
    # The pool already runs at most `max_processes` jobs at a time, so every
    # file can be submitted up front.
    with ProcessPoolExecutor(max_workers=max_processes) as executor:
        future_to_file = {
            # NOTE(review): the positional True is forwarded unchanged from the
            # original call; its meaning is defined by compress_file.
            executor.submit(
                compress_file,
                file,
                dtype,
                streaming_chunk_size,
                delete,
                True,
                hf_cache,
            ): file
            for file in file_list
        }
        for future in as_completed(future_to_file):
            file = future_to_file[future]
            try:
                future.result()
            except Exception as exc:
                failures += 1
                print(
                    f"{RED}File {file} generated an exception: {exc}{RESET}"
                )
    return failures


def compress_files_with_suffix(
    suffix,
    dtype="",
    streaming_chunk_size=1048576,
    path=".",
    delete=False,
    r=False,
    force=False,
    max_processes=1,
    hf_cache=False,
    model="",
    branch="main",
):
    """Compress every file under `path` whose name ends with `suffix`.

    Args:
        suffix: file-name suffix to match (plain ``str.endswith``).
        dtype: forwarded to compress_file (the CLI passes 32 for --float32).
        streaming_chunk_size: bytes, or a string accepted by
            parse_streaming_chunk_size (e.g. "1MB").
        path: directory to search; replaced by the cached snapshot directory
            when `model` is given.
        delete: forwarded to compress_file.
        r: search sub-directories recursively.
        force: overwrite existing ``.znn`` outputs without prompting.
        max_processes: maximum number of worker processes.
        hf_cache: `path` is a Hugging Face cache snapshot.  If any file is
            selected for compression, its weight index json is rewritten to
            reference the ``.znn`` files.
        model: repo id to look up in the local Hugging Face cache (requires
            `hf_cache`).
        branch: branch of `model` whose snapshot is compressed.

    Raises:
        ValueError: if `model` is given without `hf_cache`, or the chunk size
            is invalid.
        FileNotFoundError: if `model` or `branch` is not in the local cache.
        ImportError: if huggingface_hub/transformers are needed but missing.
    """
    import zipnn  # noqa: F401  -- unused here; raises ImportError early if missing

    streaming_chunk_size = parse_streaming_chunk_size(streaming_chunk_size)

    if model:
        if not hf_cache:
            raise ValueError("Must specify --hf_cache when using --model")
        path = _resolve_hf_snapshot_path(model, branch)

    matched, file_list = _collect_files_to_compress(path, suffix, r, force)
    if not matched:
        print(f"{RED}No files with the suffix '{suffix}' found.{RESET}")
        return

    if file_list and hf_cache:
        _point_hf_index_to_znn(path, suffix)

    failures = _compress_in_pool(
        file_list, dtype, streaming_chunk_size, delete, hf_cache, max_processes
    )
    if failures:
        print(
            f"{RED}{failures} of {len(file_list)} file(s) failed to compress{RESET}"
        )
    else:
        print(f"{GREEN}All files compressed{RESET}")
|
|
|
|
|
if __name__ == "__main__": |
|
if len(sys.argv) < 2: |
|
print( |
|
"Usage: python compress_files.py <suffix>" |
|
) |
|
print( |
|
"Example: python compress_files.py 'safetensors'" |
|
) |
|
sys.exit(1) |
|
|
|
parser = argparse.ArgumentParser( |
|
description="Enter a suffix to compress, (optional) dtype, (optional) streaming chunk size, (optional) path to files." |
|
) |
|
parser.add_argument( |
|
"suffix", |
|
type=str, |
|
help="Specify the file suffix to compress all files with that suffix. If a single file name is provided, only that file will be compressed.", |
|
) |
|
parser.add_argument( |
|
"--float32", |
|
action="store_true", |
|
help="A flag that triggers float32 compression", |
|
) |
|
parser.add_argument( |
|
"--streaming_chunk_size", |
|
type=str, |
|
help="An optional streaming chunk size. The format is int (for size in Bytes) or int+KB/MB/GB. Default is 1MB", |
|
) |
|
parser.add_argument( |
|
"--path", |
|
type=str, |
|
help="Path to files to compress", |
|
) |
|
parser.add_argument( |
|
"--delete", |
|
action="store_true", |
|
help="A flag that triggers deletion of a single file instead of compression", |
|
) |
|
parser.add_argument( |
|
"-r", |
|
action="store_true", |
|
help="A flag that triggers recursive search on all subdirectories", |
|
) |
|
parser.add_argument( |
|
"--recursive", |
|
action="store_true", |
|
help="A flag that triggers recursive search on all subdirectories", |
|
) |
|
parser.add_argument( |
|
"--force", |
|
action="store_true", |
|
help="A flag that forces overwriting when compressing.", |
|
) |
|
parser.add_argument( |
|
"--max_processes", |
|
type=int, |
|
help="The amount of maximum processes.", |
|
) |
|
parser.add_argument( |
|
"--hf_cache", |
|
action="store_true", |
|
help="A flag that indicates if the file is in the Hugging Face cache. Must either specify --model or --path to the model's snapshot cache.", |
|
) |
|
parser.add_argument( |
|
"--model", |
|
type=str, |
|
help="Only when using --hf_cache, specify the model name or path. E.g. 'ibm-granite/granite-7b-instruct'", |
|
) |
|
parser.add_argument( |
|
"--model_branch", |
|
type=str, |
|
default="main", |
|
help="Only when using --model, specify the model branch. Default is 'main'", |
|
) |
|
args = parser.parse_args() |
|
optional_kwargs = {} |
|
if args.float32: |
|
optional_kwargs["dtype"] = 32 |
|
if args.streaming_chunk_size is not None: |
|
optional_kwargs[ |
|
"streaming_chunk_size" |
|
] = args.streaming_chunk_size |
|
if args.path is not None: |
|
optional_kwargs["path"] = args.path |
|
if args.delete: |
|
optional_kwargs["delete"] = args.delete |
|
if args.r or args.recursive: |
|
optional_kwargs["r"] = args.r |
|
if args.force: |
|
optional_kwargs["force"] = args.force |
|
if args.max_processes: |
|
optional_kwargs["max_processes"] = ( |
|
args.max_processes |
|
) |
|
if args.hf_cache: |
|
optional_kwargs["hf_cache"] = args.hf_cache |
|
if args.model: |
|
optional_kwargs["model"] = args.model |
|
if args.model_branch: |
|
optional_kwargs[ |
|
"branch" |
|
] = args.model_branch |
|
|
|
check_and_install_zipnn() |
|
compress_files_with_suffix( |
|
args.suffix, **optional_kwargs |
|
) |
|
|