File size: 3,991 Bytes
95c7f67
 
5c67222
058f52d
ef5bc80
058f52d
95c7f67
 
 
 
 
 
 
 
 
 
 
058f52d
 
 
 
 
 
 
 
 
 
 
 
734cd8a
058f52d
 
 
 
 
 
8e7098c
058f52d
 
 
 
 
 
 
 
 
 
 
 
 
 
8e7098c
734cd8a
 
058f52d
 
734cd8a
 
 
 
 
 
 
 
 
 
8e7098c
 
 
 
 
734cd8a
 
8e7098c
 
 
 
 
 
 
734cd8a
95c7f67
 
 
 
 
 
058f52d
ef5bc80
734cd8a
 
8e7098c
 
95c7f67
 
ef5bc80
 
058f52d
95c7f67
ef5bc80
5c67222
95c7f67
ef5bc80
95c7f67
 
936aca5
ef5bc80
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import requests
from safetensors.torch import load_file, save_file
import torch
torch.cuda.empty_cache()
import torch.nn.functional as F
from tqdm import tqdm

def download_file(url, dest_path):
    print(f"Downloading {url} to {dest_path}")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open(dest_path, 'wb') as f:
            for chunk in response.iter_content(1024):
                f.write(chunk)
    else:
        raise Exception(f"Failed to download file from {url}")

def load_model(file_path):
    return load_file(file_path)

def save_model(merged_model, output_file):
    print(f"Saving merged model to {output_file}")
    save_file(merged_model, output_file)

def resize_tensor_shapes(tensor1, tensor2):
    if tensor1.size() == tensor2.size():
        return tensor1, tensor2

    # Resize tensor2 to match tensor1's size (Base size)
    max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]
    tensor1_resized = F.pad(tensor1, (0, max_shape[-1] - tensor1.size(-1)))
    tensor2_resized = F.pad(tensor2, (0, max_shape[-1] - tensor2.size(-1)))

    return tensor1_resized, tensor2_resized

def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.6):
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}
    all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))

    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is not None and t2 is not None:
            t1, t2 = resize_tensor_shapes(t1, t2)
            merged[key] = blend_ratio * t1 + (1 - blend_ratio) * t2
        elif t1 is not None:
            merged[key] = t1
        else:
            merged[key] = t2

    # Control the final size to be strictly 26 GB
    control_output_size(merged, target_size_gb=26)

    return merged

def control_output_size(merged, target_size_gb):
    # Estimate the size in bytes
    target_size_bytes = target_size_gb * 1024**3  # Convert GB to bytes
    current_size_bytes = sum(tensor.numel() * tensor.element_size() for tensor in merged.values())

    # If the current size exceeds the target, truncate the tensors
    if current_size_bytes > target_size_bytes:
        excess_size = current_size_bytes - target_size_bytes
        print(f"Current size exceeds target by {excess_size / (1024**2):.2f} MB. Adjusting...")
        
        # Calculate the total number of elements to reduce
        elements_to_reduce = excess_size // 4  # Assuming 4 bytes per float32 tensor
        total_elements = sum(tensor.numel() for tensor in merged.values())

        # Distribute the reduction uniformly across all tensors
        for key in merged.keys():
            tensor = merged[key]
            num_elements = tensor.numel()
            # Calculate how much to reduce from this tensor
            reduction = min(elements_to_reduce, num_elements)
            merged[key] = tensor.flatten()[:num_elements - reduction].view(tensor.shape)
            elements_to_reduce -= reduction
            if elements_to_reduce <= 0:
                break

def cleanup_files(*file_paths):
    for file_path in file_paths:
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f"Deleted {file_path}")

if __name__ == "__main__":
    try:
        model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
        model2_path = "output_checkpoint.safetensors"
        blend_ratio = 0.6  # Set to 60%
        output_file = "output_checkpoint.safetensors"

        # Loading models
        model1 = load_model(model1_path)
        model2 = load_model(model2_path)

        # Merging models
        merged_model = merge_checkpoints(model1, model2, blend_ratio)

        # Saving merged model
        save_model(merged_model, output_file)

        # Cleaning up downloaded files
        cleanup_files(model1_path)
        
    except Exception as e:
        print(f"An error occurred: {e}")