File size: 1,854 Bytes
5c67222 058f52d ef5bc80 058f52d ef5bc80 058f52d ef5bc80 058f52d ef5bc80 5c67222 ef5bc80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from safetensors.torch import load_file, save_file
import torch
torch.cuda.empty_cache()
import torch.nn.functional as F
from tqdm import tqdm # Ensure tqdm is installed
def load_model(file_path):
    """Load a safetensors checkpoint from disk.

    Args:
        file_path: Path of the ``.safetensors`` file, passed straight to
            ``safetensors.torch.load_file``.

    Returns:
        Whatever ``load_file`` returns for that path; the rest of this
        script treats it as a mapping of tensor names to tensors.
    """
    checkpoint = load_file(file_path)
    return checkpoint
def save_model(merged_model, output_file):
    """Write the merged checkpoint to ``output_file`` in safetensors format.

    Args:
        merged_model: The tensors to store, handed unchanged to
            ``safetensors.torch.save_file``.
        output_file: Destination path for the new checkpoint.
    """
    # Announce the destination before the write starts.
    announcement = f"Saving merged model to {output_file}"
    print(announcement)
    save_file(merged_model, output_file)
def resize_tensor_shapes(tensor1, tensor2):
    """Zero-pad two tensors so that they end up with one common shape.

    Every dimension is grown (at its trailing end) to the larger of the two
    sizes, so the results can be combined element-wise. Previously only the
    last dimension was padded, which left tensors that differ in any other
    dimension with incompatible shapes.

    Args:
        tensor1: First tensor.
        tensor2: Second tensor; must have the same number of dimensions as
            ``tensor1`` unless the shapes are already identical.

    Returns:
        A ``(tensor1, tensor2)`` tuple. The inputs are returned untouched
        when their shapes already match; otherwise both are zero-padded
        copies sharing the element-wise maximum shape.

    Raises:
        ValueError: If the tensors have a different number of dimensions.
            There is no unambiguous way to align them, and guessing could
            silently crop or broadcast the data.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1, tensor2

    if tensor1.dim() != tensor2.dim():
        raise ValueError(
            "Cannot align tensors with different ranks: "
            f"{tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
        )

    target_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]

    def _pad_to_target(tensor):
        # F.pad takes (before, after) pairs ordered from the LAST dimension
        # backwards, hence the reversed walk over the shapes.
        padding = []
        for current, wanted in zip(reversed(tensor.shape), reversed(target_shape)):
            padding.extend((0, wanted - current))
        return F.pad(tensor, tuple(padding))

    return _pad_to_target(tensor1), _pad_to_target(tensor2)
def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
    """Blend two checkpoints entry by entry into a new dictionary.

    Args:
        ckpt1: First checkpoint, a mapping from key to tensor.
        ckpt2: Second checkpoint, a mapping from key to tensor.
        blend_ratio: Weight applied to ``ckpt1`` entries; ``ckpt2`` entries
            are weighted by ``1 - blend_ratio``.

    Returns:
        A dict covering the union of both key sets. Keys present in both
        inputs hold the weighted sum (after ``resize_tensor_shapes`` has
        been applied to the pair); keys present in only one input carry
        that input's value over unchanged.
    """
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")

    key_pool = set(ckpt1.keys()) | set(ckpt2.keys())
    result = {}

    progress = tqdm(key_pool, desc="Merging Checkpoints", unit="layer")
    for name in progress:
        first = ckpt1.get(name)
        second = ckpt2.get(name)

        # Entry missing on one side: take the other side as-is.
        if first is None:
            result[name] = second
            continue
        if second is None:
            result[name] = first
            continue

        # Present on both sides: align the shapes, then mix.
        first, second = resize_tensor_shapes(first, second)
        result[name] = blend_ratio * first + (1 - blend_ratio) * second

    return result
if __name__ == "__main__":
    # Script entry point: load two checkpoints, blend them, write the result.
    import sys

    # Inputs and output are fixed paths relative to the working directory.
    model1_path = "flux1-dev.safetensors.1"
    model2_path = "brainflux_v10.safetensors"
    blend_ratio = 0.4  # weight given to model1; model2 gets 1 - blend_ratio
    output_file = "output_checkpoint.safetensors"

    try:
        model1 = load_model(model1_path)
        model2 = load_model(model2_path)
        merged_model = merge_checkpoints(model1, model2, blend_ratio)
        save_model(merged_model, output_file)
    except Exception as e:
        # Top-level boundary: report the failure AND exit non-zero, so a
        # calling shell or pipeline can tell the merge did not succeed.
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f"An error occurred: {e}")