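"""Merge two safetensors model checkpoints by a weighted blend.

Loads both checkpoints, blends tensors that appear in both with a
configurable ratio (zero-padding mismatched last dimensions), copies
tensors unique to either checkpoint, optionally truncates the result
to a target file size, and saves the merged checkpoint.
"""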
import os

import requests
import torch
import torch.nn.functional as F
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# Release any cached GPU memory before loading large checkpoints
# (no-op if CUDA has not been initialized).
torch.cuda.empty_cache()
def download_file(url, dest_path):
    """Stream a file from `url` to `dest_path` (helper; not called in the main flow below)."""
    print(f"Downloading {url} to {dest_path}")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open(dest_path, 'wb') as f:
            for chunk in response.iter_content(1024):
                f.write(chunk)
    else:
        raise Exception(f"Failed to download file from {url} (status {response.status_code})")
def load_model(file_path):
    return load_file(file_path)

def save_model(merged_model, output_file):
    print(f"Saving merged model to {output_file}")
    save_file(merged_model, output_file)
def resize_tensor_shapes(tensor1, tensor2):
    if tensor1.size() == tensor2.size():
        return tensor1, tensor2
    # Zero-pad both tensors along the last dimension up to the larger size.
    # Note: only the last dimension is padded; a mismatch in any other
    # dimension will still raise an error when the tensors are blended.
    max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]
    tensor1_resized = F.pad(tensor1, (0, max_shape[-1] - tensor1.size(-1)))
    tensor2_resized = F.pad(tensor2, (0, max_shape[-1] - tensor2.size(-1)))
    return tensor1_resized, tensor2_resized
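# Illustrative example (shapes are hypothetical): blending a weight of shape
# (4096, 3072) with one of shape (4096, 3584) pads the narrower tensor to
# (4096, 3584) with trailing zeros, so in the padded region the blend reduces
# to (1 - blend_ratio) times the wider tensor's values.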
def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.6):
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}
    all_keys = set(ckpt1.keys()) | set(ckpt2.keys())
    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is not None and t2 is not None:
            # Both checkpoints have this tensor: pad to a common shape,
            # then take the weighted average.
            t1, t2 = resize_tensor_shapes(t1, t2)
            merged[key] = blend_ratio * t1 + (1 - blend_ratio) * t2
        elif t1 is not None:
            merged[key] = t1
        else:
            merged[key] = t2
    # Cap the final checkpoint size at 26 GB
    control_output_size(merged, target_size_gb=26)
    return merged
def control_output_size(merged, target_size_gb):
    """Truncate tensors so the total payload stays under `target_size_gb`.

    Warning: truncation discards trailing elements and leaves the affected
    tensors flattened, since the original shape cannot hold fewer elements.
    This caps file size at the cost of corrupting the truncated tensors.
    """
    target_size_bytes = target_size_gb * 1024**3  # GB -> bytes
    current_size_bytes = sum(t.numel() * t.element_size() for t in merged.values())
    if current_size_bytes <= target_size_bytes:
        return
    excess_bytes = current_size_bytes - target_size_bytes
    print(f"Current size exceeds target by {excess_bytes / (1024**2):.2f} MB. Adjusting...")
    # Greedily walk the tensors and drop trailing elements until enough bytes
    # are freed, using each tensor's actual element size (2 bytes for
    # bfloat16, 4 for float32) rather than assuming float32.
    for key, tensor in merged.items():
        reduction = min(excess_bytes // tensor.element_size(), tensor.numel())
        if reduction == 0:
            continue
        merged[key] = tensor.flatten()[:tensor.numel() - reduction]
        excess_bytes -= reduction * tensor.element_size()
        if excess_bytes <= 0:
            break
def cleanup_files(*file_paths):
    for file_path in file_paths:
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f"Deleted {file_path}")
if __name__ == "__main__":
    try:
        model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
        model2_path = "output_checkpoint.safetensors"
        blend_ratio = 0.6  # 60% model1, 40% model2
        # Note: the output path matches model2_path, so the second input
        # checkpoint is overwritten in place by the merged result.
        output_file = "output_checkpoint.safetensors"
        # Load both checkpoints
        model1 = load_model(model1_path)
        model2 = load_model(model2_path)
        # Merge the models
        merged_model = merge_checkpoints(model1, model2, blend_ratio)
        # Save the merged model
        save_model(merged_model, output_file)
        # Remove the first input checkpoint to reclaim disk space
        cleanup_files(model1_path)
    except Exception as e:
        print(f"An error occurred: {e}")