import os

import requests
import torch
import torch.nn.functional as F
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# Free any cached GPU memory before loading large checkpoints
torch.cuda.empty_cache()

def download_file(url, dest_path):
    """Stream a file from `url` to `dest_path`."""
    print(f"Downloading {url} to {dest_path}")
    response = requests.get(url, stream=True)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download file from {url} (status {response.status_code})")
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

def load_model(file_path):
    return load_file(file_path)

def save_model(merged_model, output_file):
    print(f"Saving merged model to {output_file}")
    save_file(merged_model, output_file)

def resize_tensor_shapes(tensor1, tensor2):
    if tensor1.size() == tensor2.size():
        return tensor1, tensor2

    # Zero-pad both tensors on the right of their last dimension so they share the
    # same trailing size. Mismatches in other dimensions are not handled here and
    # will still raise when the tensors are blended.
    max_last_dim = max(tensor1.size(-1), tensor2.size(-1))
    tensor1_resized = F.pad(tensor1, (0, max_last_dim - tensor1.size(-1)))
    tensor2_resized = F.pad(tensor2, (0, max_last_dim - tensor2.size(-1)))

    return tensor1_resized, tensor2_resized
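
# For illustration (hypothetical shapes, not from the original script):
# resize_tensor_shapes(torch.zeros(4, 768), torch.zeros(4, 1024)) returns two
# tensors of shape (4, 1024); the first is zero-padded on the right of its last dimension.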

def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.75):
    """Blend two state dicts as blend_ratio * ckpt1 + (1 - blend_ratio) * ckpt2.

    Keys present in only one checkpoint are copied through unchanged.
    """
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}
    all_keys = set(ckpt1.keys()) | set(ckpt2.keys())

    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is not None and t2 is not None:
            t1, t2 = resize_tensor_shapes(t1, t2)
            merged[key] = blend_ratio * t1 + (1 - blend_ratio) * t2
        elif t1 is not None:
            merged[key] = t1
        else:
            merged[key] = t2

    # Cap the final size at approximately 26 GB
    control_output_size(merged, target_size_gb=26)

    return merged
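
# With the default blend_ratio of 0.75, an element that is 1.0 in ckpt1 and 0.0 in
# ckpt2 merges to 0.75 * 1.0 + 0.25 * 0.0 = 0.75; the merge is a simple convex combination.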

def control_output_size(merged, target_size_gb):
    # Estimate the current size of the state dict in bytes
    target_size_bytes = target_size_gb * 1024**3  # Convert GB to bytes
    current_size_bytes = sum(tensor.numel() * tensor.element_size() for tensor in merged.values())

    excess_bytes = current_size_bytes - target_size_bytes
    if excess_bytes <= 0:
        return

    print(f"Current size exceeds target by {excess_bytes / (1024**2):.2f} MB. Adjusting...")

    # Truncate flattened tensors until enough bytes have been dropped. Note that this
    # changes tensor shapes, so the truncated checkpoint may no longer load into the
    # original architecture.
    for key, tensor in merged.items():
        if excess_bytes <= 0:
            break
        elements_to_drop = min(tensor.numel() - 1, excess_bytes // tensor.element_size())
        if elements_to_drop > 0:
            merged[key] = tensor.flatten()[: tensor.numel() - elements_to_drop]
            excess_bytes -= elements_to_drop * tensor.element_size()
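
# A less destructive way to shrink the file (not used in the original script) is to
# cast each tensor to a lower-precision dtype before saving, for example:
#
#   for key in merged:
#       merged[key] = merged[key].to(torch.bfloat16)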

def cleanup_files(*file_paths):
    for file_path in file_paths:
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f"Deleted {file_path}")
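
# download_file() is defined above but never called in the main block below. If the
# two checkpoints are hosted remotely, they could be fetched first; the URLs here
# are placeholders, not real endpoints:
#
#   download_file("https://example.com/base.safetensors", model1_path)
#   download_file("https://example.com/second.safetensors", model2_path)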

if __name__ == "__main__":
    try:
        model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
        model2_path = "output_checkpoint.safetensors"
        blend_ratio = 0.75  # Weight given to model1; model2 contributes (1 - blend_ratio)
        output_file = "output_checkpoints.safetensors"

        # Loading models
        model1 = load_model(model1_path)
        model2 = load_model(model2_path)

        # Merging models
        merged_model = merge_checkpoints(model1, model2, blend_ratio)

        # Saving merged model
        save_model(merged_model, output_file)

        # Deleting the input checkpoints now that the merge has been saved
        cleanup_files(model1_path, model2_path)
        
    except Exception as e:
        print(f"An error occurred: {e}")