File size: 1,898 Bytes
5c67222
058f52d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c67222
058f52d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from safetensors.torch import load_file, save_file
import torch
import torch.nn.functional as F
from tqdm import tqdm  # Ensure tqdm is installed

def load_model(file_path):
    """Load a ``.safetensors`` file and return the dict of tensors it contains.

    Args:
        file_path: Path of the safetensors file to read.

    Returns:
        The mapping produced by ``safetensors.torch.load_file``.
    """
    state_dict = load_file(file_path)
    return state_dict

def save_model(merged_model, output_file):
    """Announce and write ``merged_model`` to ``output_file`` via ``save_file``.

    Args:
        merged_model: Mapping of tensor name -> tensor to serialize.
        output_file: Destination path of the safetensors file.
    """
    announcement = f"Saving merged model to {output_file}"
    print(announcement)
    save_file(merged_model, output_file)

def resize_tensor_shapes(tensor1, tensor2):
    """Zero-pad two same-rank tensors so that both have the same shape.

    Every dimension is grown (at the end, with zeros) to the larger of the
    two sizes, so the results can be combined element-wise.

    NOTE: padded entries are zeros, so a later weighted blend of the padded
    region only carries the contribution of the larger tensor.

    Args:
        tensor1: First tensor.
        tensor2: Second tensor; must have the same number of dimensions.

    Returns:
        ``(tensor1, tensor2)`` unchanged when the shapes already match,
        otherwise a pair of new, padded tensors of identical shape.

    Raises:
        ValueError: If the tensors have a different number of dimensions
            (padding cannot reconcile differing ranks).
    """
    if tensor1.size() == tensor2.size():
        return tensor1, tensor2

    if tensor1.dim() != tensor2.dim():
        raise ValueError(
            "Cannot reconcile tensors of different rank: "
            f"{tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
        )

    max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]

    def _pad_to_max(tensor):
        # F.pad expects (left, right) pairs starting from the LAST dimension,
        # so walk the shapes in reverse and pad only on the right side.
        padding = []
        for target, current in zip(reversed(max_shape), reversed(tensor.shape)):
            padding.extend((0, target - current))
        return F.pad(tensor, tuple(padding))

    return _pad_to_max(tensor1), _pad_to_max(tensor2)

def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
    """Linearly blend two checkpoints tensor by tensor.

    For each key present in both checkpoints the result is
    ``blend_ratio * ckpt1[key] + (1 - blend_ratio) * ckpt2[key]``, after both
    tensors are zero-padded to a common shape with ``resize_tensor_shapes``.
    Keys present in only one checkpoint are copied through unchanged.

    Args:
        ckpt1: Mapping of tensor name -> tensor; weighted by ``blend_ratio``.
        ckpt2: Mapping of tensor name -> tensor; weighted by ``1 - blend_ratio``.
        blend_ratio: Weight applied to ``ckpt1`` (default 0.5).

    Returns:
        A new dict containing the union of keys from both checkpoints.
    """
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}
    # Sorted so the processing order (and progress output) is reproducible;
    # iteration order of a set of strings varies between interpreter runs.
    all_keys = sorted(set(ckpt1) | set(ckpt2))

    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is None:
            merged[key] = t2
            continue
        if t2 is None:
            merged[key] = t1
            continue

        t1, t2 = resize_tensor_shapes(t1, t2)
        if t1.is_floating_point() and t2.is_floating_point():
            # Same output dtype the plain expression would produce ...
            out_dtype = torch.promote_types(t1.dtype, t2.dtype)
            # ... but computed in at least float32 and rounded once, instead
            # of rounding after each multiply and the add in bf16/fp16.
            compute_dtype = torch.promote_types(out_dtype, torch.float32)
            blended = (
                blend_ratio * t1.to(compute_dtype)
                + (1 - blend_ratio) * t2.to(compute_dtype)
            )
            merged[key] = blended.to(out_dtype)
        else:
            merged[key] = blend_ratio * t1 + (1 - blend_ratio) * t2

    return merged

if __name__ == "__main__":
    import argparse

    # Defaults reproduce the previously hard-coded settings, so running the
    # script without arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Blend two .safetensors checkpoints tensor by tensor."
    )
    # NOTE(review): the ".1" suffix looks like a duplicate-download file name
    # (e.g. from wget); confirm this is the intended path.
    parser.add_argument(
        "--model1",
        default="flux1-dev.safetensors.1",
        help="first checkpoint path (weighted by --blend-ratio)",
    )
    parser.add_argument(
        "--model2",
        default="brainflux_v10.safetensors",
        help="second checkpoint path (weighted by 1 - blend ratio)",
    )
    parser.add_argument(
        "--blend-ratio",
        type=float,
        default=0.4,
        help="weight applied to the first checkpoint",
    )
    parser.add_argument(
        "--output",
        default="output_checkpoint.safetensors",
        help="path of the merged checkpoint to write",
    )
    args = parser.parse_args()

    # Load the models
    model1 = load_model(args.model1)
    model2 = load_model(args.model2)

    # Merge the models
    merged_model = merge_checkpoints(model1, model2, args.blend_ratio)

    # Release the source dicts so their tensors (other than those copied
    # through into merged_model) can be reclaimed before serialization.
    del model1, model2

    # Save the merged model
    save_model(merged_model, args.output)