|
from safetensors.torch import load_file, save_file |
|
import torch |
|
import torch.nn.functional as F |
|
from tqdm import tqdm |
|
|
|
def load_model(file_path):
    """Load every tensor stored in a ``.safetensors`` file.

    Delegates to ``safetensors.torch.load_file`` and hands back its result
    (a mapping from tensor name to tensor) without modification.
    """
    state_dict = load_file(file_path)
    return state_dict
|
|
|
def save_model(merged_model, output_file):
    """Serialize ``merged_model`` to ``output_file`` with safetensors.

    Announces the destination on stdout before writing.
    """
    banner = f"Saving merged model to {output_file}"
    print(banner)
    save_file(merged_model, output_file)
|
|
|
def resize_tensor_shapes(tensor1, tensor2):
    """Zero-pad two tensors so that they share one common shape.

    Every dimension is padded at its end (with zeros) up to the larger of the
    two sizes in that dimension, so the results can be combined element-wise.
    If the shapes already match, the inputs are returned unchanged (the very
    same tensor objects).

    Args:
        tensor1: First tensor.
        tensor2: Second tensor; must have the same number of dimensions as
            ``tensor1``.

    Returns:
        ``(tensor1_resized, tensor2_resized)`` with identical shapes.

    Raises:
        ValueError: If the tensors differ in rank, because there is no
            unambiguous way to align their dimensions.
    """
    if tensor1.size() == tensor2.size():
        return tensor1, tensor2

    if tensor1.dim() != tensor2.dim():
        raise ValueError(
            "Cannot align tensors of different rank: "
            f"{tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
        )

    max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]

    def _pad_to_max(tensor):
        # F.pad expects (left, right) pairs starting from the LAST dimension,
        # so walk the dimensions in reverse and pad only on the right side.
        pad = []
        for target, size in zip(reversed(max_shape), reversed(tensor.shape)):
            pad.extend((0, target - size))
        return F.pad(tensor, tuple(pad))

    return _pad_to_max(tensor1), _pad_to_max(tensor2)
|
|
|
def _blend_tensors(t1, t2, blend_ratio):
    """Return ``blend_ratio * t1 + (1 - blend_ratio) * t2`` for same-shape tensors."""
    if not (t1.is_floating_point() and t2.is_floating_point()):
        # Non-floating tensors keep the plain arithmetic; PyTorch promotes the
        # result to a floating dtype when multiplying by a Python float.
        return blend_ratio * t1 + (1 - blend_ratio) * t2

    out_dtype = (
        t1.dtype if t1.dtype == t2.dtype
        else torch.promote_types(t1.dtype, t2.dtype)
    )
    # Accumulate in at least float32: half-precision dtypes (bf16/fp16) round
    # after every op, and most PyTorch arithmetic kernels are not implemented
    # for float8 dtypes. Cast back so the merged tensor keeps the storage dtype.
    compute_dtype = torch.float64 if out_dtype == torch.float64 else torch.float32
    blended = (
        blend_ratio * t1.to(compute_dtype)
        + (1 - blend_ratio) * t2.to(compute_dtype)
    )
    return blended.to(out_dtype)


def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
    """Linearly blend two state dicts key by key.

    For every key present in both checkpoints, the two tensors are first
    zero-padded to a common shape with ``resize_tensor_shapes`` and then
    combined as ``blend_ratio * t1 + (1 - blend_ratio) * t2``. Floating-point
    tensors are accumulated in float32 (float64 for float64 inputs) and cast
    back to their (promoted) input dtype. Keys present in only one checkpoint
    are copied through unchanged.

    Args:
        ckpt1: Mapping of tensor name -> tensor, weighted by ``blend_ratio``.
        ckpt2: Mapping of tensor name -> tensor, weighted by
            ``1 - blend_ratio``.
        blend_ratio: Weight given to ``ckpt1``.

    Returns:
        A new dict containing every key from either checkpoint.
    """
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}
    # Set iteration order is not stable between runs; sorting makes the merge
    # order (and the progress display) reproducible.
    all_keys = sorted(set(ckpt1.keys()) | set(ckpt2.keys()))

    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is None:
            # Key only in ckpt2 (keys come from the union, so t2 is present).
            merged[key] = t2
        elif t2 is None:
            merged[key] = t1
        else:
            t1, t2 = resize_tensor_shapes(t1, t2)
            merged[key] = _blend_tensors(t1, t2, blend_ratio)

    return merged
|
|
|
def main(argv=None):
    """Command-line entry point: load two checkpoints, blend them, save.

    All options default to the values this script previously hard-coded, so
    running it with no arguments behaves as before.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Blend two safetensors checkpoints as "
            "blend_ratio * model1 + (1 - blend_ratio) * model2."
        )
    )
    # NOTE(review): the ".1" suffix looks like a duplicate-download artifact;
    # kept as the default to preserve previous behavior — confirm the path.
    parser.add_argument(
        "--model1",
        default="flux1-dev.safetensors.1",
        help="first checkpoint (weighted by --blend-ratio)",
    )
    parser.add_argument(
        "--model2",
        default="brainflux_v10.safetensors",
        help="second checkpoint (weighted by 1 - --blend-ratio)",
    )
    parser.add_argument(
        "--blend-ratio",
        type=float,
        default=0.4,
        help="weight given to --model1",
    )
    parser.add_argument(
        "--output",
        default="output_checkpoint.safetensors",
        help="path of the merged safetensors file",
    )
    args = parser.parse_args(argv)

    model1 = load_model(args.model1)
    model2 = load_model(args.model2)

    merged_model = merge_checkpoints(model1, model2, args.blend_ratio)

    # Drop the input dicts before writing so tensors that were blended (and
    # are no longer referenced by merged_model) can be freed first.
    del model1, model2

    save_model(merged_model, args.output)


if __name__ == "__main__":
    main()