|
import argparse, json, math, os |
|
from safetensors import safe_open |
|
from safetensors.torch import save_file |
|
|
|
# Command line: split INPUT_FILE into shards of at most SHARD_SIZE megabytes each
# (a single tensor larger than the limit still gets its own shard).
parser = argparse.ArgumentParser(description = "Split .safetensors file into shards")
parser.add_argument("input_file", type = str, help = "Path to input file")
parser.add_argument("shard_size", type = int, help = "Shard size in megabytes")
args = parser.parse_args()

# A non-positive limit can never be satisfied and would silently produce one
# shard per tensor; reject it up front with a usage error (exit status 2).
if args.shard_size <= 0:
    parser.error("shard_size must be a positive number of megabytes")

input_file = args.input_file

# Shard file names are derived from the input path with its extension removed,
# e.g. "dir/model.safetensors" -> "dir/model-00001-of-00003.safetensors".
input_base, _ = os.path.splitext(input_file)

# Limit in bytes (the "megabytes" are binary: 1024**2), compared against raw
# tensor byte sizes when packing shards.
shard_size = args.shard_size * 1024**2
|
|
|
|
|
|
|
# Bytes per element for each safetensors dtype tag this script can size.
# Sub-byte formats are intentionally absent and fall through to the error in _tsize.
_DTYPE_BYTES = {
    "BOOL": 1, "U8": 1, "I8": 1, "F8_E4M3": 1, "F8_E5M2": 1,
    "I16": 2, "U16": 2, "F16": 2, "BF16": 2,
    "I32": 4, "U32": 4, "F32": 4,
    "I64": 8, "U64": 8, "F64": 8,
}


def _tsize(st, key):
    """Return the size in bytes of tensor *key* in the open safetensors handle *st*.

    Only the slice's shape and dtype are queried (``get_slice``); the function
    never calls ``get_tensor``.

    Raises:
        ValueError: if the tensor's dtype tag is not listed in ``_DTYPE_BYTES``.
    """
    tslice = st.get_slice(key)
    shape = tslice.get_shape()
    dtype = tslice.get_dtype()
    del tslice

    # An empty shape (scalar tensor) yields one element.
    numel = 1
    for dim in shape:
        numel *= dim

    try:
        elem_bytes = _DTYPE_BYTES[dtype]
    except KeyError:
        raise ValueError(f"Unexpected datatype: {dtype} (tensor {key})") from None
    return numel * elem_bytes
|
|
|
# Plan the shards: walk the tensors in the order safe_open reports them and
# greedily pack them, opening a new shard whenever the next tensor would push
# the current one past shard_size. A tensor bigger than shard_size ends up
# alone in its own (oversized) shard.
num_files = 0
total_size = 0
tensor_map = []                 # one list of tensor keys per planned shard
current_size = shard_size + 1   # sentinel so the very first tensor opens a shard

print(f" -- Scanning tensors in {input_file}")

with safe_open(input_file, framework = "pt", device = "cpu") as f:
    for key in f.keys():
        tensor_size = _tsize(f, key)
        total_size += tensor_size

        fits = current_size + tensor_size <= shard_size
        if not fits:
            current_list = []
            tensor_map.append(current_list)
            num_files += 1
            current_size = 0

        current_list.append(key)
        current_size += tensor_size
|
|
|
|
|
|
|
# Maps tensor key -> shard file name, for the index written afterwards.
weight_map = {}

# Open the input once and materialize one planned shard at a time; `shard` is
# rebound on every iteration, so only one shard's tensors are kept alive.
with safe_open(input_file, framework = "pt", device = "cpu") as f:
    for file_index, keys in enumerate(tensor_map):
        shard_filename = f"{input_base}-{file_index + 1:05}-of-{num_files:05}.safetensors"

        # The index file lives next to the shards, and loaders resolve its
        # entries relative to the index's directory, so record only the file
        # name rather than the (possibly relative) path used to write it.
        shard_basename = os.path.basename(shard_filename)

        shard = {}
        for key in keys:
            print(f" -- Reading: {key}")
            shard[key] = f.get_tensor(key)
            weight_map[key] = shard_basename

        print(f" -- Writing: {shard_filename}")
        save_file(shard, shard_filename)
|
|
|
|
|
|
|
# Emit "<input_file>.index.json" (e.g. model.safetensors.index.json) holding
# the summed tensor byte count and the tensor -> shard assignment.
index_filename = f"{input_file}.index.json"
index = {
    "metadata": {"total_size": total_size},
    "weight_map": weight_map,
}

print(f" -- Writing: {index_filename}")

with open(index_filename, 'w') as index_file:
    index_file.write(json.dumps(index, indent = 2))

print(" -- Done")