# Mixtral-8x22B-v0.1 / split.py
# Author: leafspark
# Commit e1cc7f1 (verified): "Add safetensors merge and split helper files"
import os
import math
import json
CHUNK_SIZE = 2 * 1024**3  # 2 GiB (2,147,483,648 bytes) per chunk
# NOTE(review): the original comment on this line said "40GB", but the value is
# 2 GiB. Confirm which chunk size was intended.
CHUNK_PATHS_FILE = "chunk_paths.json"  # resolved relative to the current working directory

_COPY_BUFFER_SIZE = 64 * 1024**2  # 64 MiB: caps memory used while copying one chunk


def _copy_bytes(f_in, f_out, length, buffer_size=_COPY_BUFFER_SIZE):
    """Copy up to *length* bytes from *f_in* to *f_out* in bounded reads.

    Reads at most *buffer_size* bytes at a time. Stops early at end of file,
    so the final chunk of a file may be shorter than *length*.

    Returns:
        The number of bytes actually copied.
    """
    copied = 0
    while copied < length:
        buf = f_in.read(min(buffer_size, length - copied))
        if not buf:  # EOF reached before *length* bytes
            break
        f_out.write(buf)
        copied += len(buf)
    return copied


def split(filepath, chunk_size=CHUNK_SIZE):
    """Split *filepath* into numbered chunk files of at most *chunk_size* bytes.

    Chunks are written to the same directory as *filepath* and named
    ``<stem>-<i>-of-<n><ext>``, where *i* is zero-padded to the width of *n*.
    For example, ``consolidated.safetensors`` becomes
    ``consolidated-01-of-12.safetensors`` and so on. The ordered list of chunk
    paths is also written as JSON to ``CHUNK_PATHS_FILE`` in the current
    working directory. An empty input file produces no chunks and an empty
    list.

    Args:
        filepath: Path of the file to split.
        chunk_size: Maximum size of each chunk, in bytes. Must be positive.

    Returns:
        The list of chunk file paths, in order.

    Raises:
        ValueError: If *chunk_size* is not positive.
        OSError: If the source cannot be read or a chunk/JSON file cannot be
            written.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive number of bytes, got {chunk_size!r}")

    dirname, basename = os.path.split(filepath)
    # splitext strips only the last suffix ("a.b.safetensors" -> "a.b",
    # ".safetensors") and tolerates names that have no extension at all.
    stem, extension = os.path.splitext(basename)

    file_size = os.path.getsize(filepath)
    # Integer ceiling division: exact for any size, unlike float division.
    num_chunks = -(-file_size // chunk_size)
    digit_count = len(str(num_chunks))

    chunk_paths = []
    # Open the source once and copy sequentially. Each chunk starts exactly
    # where the previous one stopped, so no per-chunk reopen/seek is needed.
    with open(filepath, "rb") as f_in:
        for i in range(1, num_chunks + 1):
            chunk_filename = f"{stem}-{i:0{digit_count}d}-of-{num_chunks}{extension}"
            split_path = os.path.join(dirname, chunk_filename)
            with open(split_path, "wb") as f_out:
                _copy_bytes(f_in, f_out, chunk_size)
            chunk_paths.append(split_path)

    # Record the chunk order (presumably consumed by the companion merge
    # helper; verify against merge.py).
    with open(CHUNK_PATHS_FILE, 'w') as f:
        json.dump(chunk_paths, f)
    return chunk_paths
if __name__ == "__main__":
    # Guarded so that importing this module does not start a multi-GB split.
    import sys

    # File to be split: the first command-line argument if one is given,
    # otherwise the original hard-coded checkpoint name.
    main_filepath = sys.argv[1] if len(sys.argv) > 1 else "consolidated.safetensors"
    chunk_paths = split(main_filepath)