Update A.py
Browse files
A.py
CHANGED
@@ -27,13 +27,14 @@ def resize_tensor_shapes(tensor1, tensor2):
|
|
27 |
if tensor1.size() == tensor2.size():
|
28 |
return tensor1, tensor2
|
29 |
|
|
|
30 |
max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]
|
31 |
tensor1_resized = F.pad(tensor1, (0, max_shape[-1] - tensor1.size(-1)))
|
32 |
tensor2_resized = F.pad(tensor2, (0, max_shape[-1] - tensor2.size(-1)))
|
33 |
|
34 |
return tensor1_resized, tensor2_resized
|
35 |
|
36 |
-
def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.
|
37 |
print(f"Merging checkpoints with blend ratio: {blend_ratio}")
|
38 |
merged = {}
|
39 |
all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))
|
@@ -48,8 +49,30 @@ def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
|
|
48 |
else:
|
49 |
merged[key] = t2
|
50 |
|
|
|
|
|
|
|
51 |
return merged
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
def cleanup_files(*file_paths):
|
54 |
for file_path in file_paths:
|
55 |
if os.path.exists(file_path):
|
@@ -58,14 +81,10 @@ def cleanup_files(*file_paths):
|
|
58 |
|
59 |
if __name__ == "__main__":
|
60 |
try:
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
output_file = "output_checkpoint.safetensors"
|
66 |
-
|
67 |
-
# Downloading files
|
68 |
-
|
69 |
|
70 |
# Loading models
|
71 |
model1 = load_model(model1_path)
|
|
|
27 |
def resize_tensor_shapes(tensor1, tensor2):
    """Zero-pad two tensors (at the high end of each dimension) to a common shape.

    Assumes both tensors have the same number of dimensions — ``zip`` would
    silently truncate otherwise (same implicit assumption as the original code;
    TODO confirm callers guarantee equal rank).

    Args:
        tensor1: first tensor.
        tensor2: second tensor.

    Returns:
        (tensor1_resized, tensor2_resized): tensors sharing the element-wise
        maximum shape. The originals are returned untouched when the shapes
        already match.
    """
    if tensor1.size() == tensor2.size():
        return tensor1, tensor2

    # Element-wise maximum size along every dimension.
    max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]

    def _pad_to_max(t):
        # F.pad expects pad widths ordered from the LAST dimension forward:
        # (last_lo, last_hi, second_last_lo, second_last_hi, ...).
        # BUG FIX: the original computed max_shape over all dimensions but
        # padded only the final one, so tensors differing in any earlier
        # dimension came back still mismatched.
        pad = []
        for size, target in zip(reversed(t.shape), reversed(max_shape)):
            pad.extend((0, target - size))
        return F.pad(t, pad)

    return _pad_to_max(tensor1), _pad_to_max(tensor2)
|
36 |
|
37 |
+
def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.75):
|
38 |
print(f"Merging checkpoints with blend ratio: {blend_ratio}")
|
39 |
merged = {}
|
40 |
all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))
|
|
|
49 |
else:
|
50 |
merged[key] = t2
|
51 |
|
52 |
+
# Control the final size to be approximately 26 GB
|
53 |
+
control_output_size(merged, target_size_gb=26)
|
54 |
+
|
55 |
return merged
|
56 |
|
57 |
+
def control_output_size(merged, target_size_gb):
    """Shrink the tensors in *merged* IN PLACE until their total byte size
    is at most ``target_size_gb`` gibibytes.

    WARNING(review): truncating checkpoint tensors discards trained weights
    and will corrupt the model — this only enforces a file-size budget. It is
    kept because the original code had the same intent, but confirm this is
    really wanted before shipping.

    Args:
        merged: dict mapping parameter names to torch tensors; mutated in place.
        target_size_gb: size budget in GiB.

    Returns:
        None. ``merged`` is modified in place.
    """
    target_size_bytes = target_size_gb * 1024**3  # Convert GB to bytes
    current_size_bytes = sum(tensor.numel() * tensor.element_size() for tensor in merged.values())

    excess_size = current_size_bytes - target_size_bytes
    if excess_size <= 0:
        return  # already within budget; nothing to adjust

    print(f"Current size exceeds target by {excess_size / (1024**2):.2f} MB. Adjusting...")

    # BUG FIX: the original sliced the flat tensor and then called
    # .view(tensor.shape), which always raises (the element count no longer
    # matches the old shape), and it never decremented the running excess, so
    # every tensor was cut by the full overshoot. Instead, drop whole leading
    # rows (the tensor keeps a valid shape) and track the remaining excess,
    # stopping as soon as the budget is met.
    for key, tensor in list(merged.items()):
        if excess_size <= 0:
            break
        # Scalars and single-row tensors cannot lose a leading row.
        if tensor.dim() == 0 or tensor.size(0) <= 1:
            continue
        # Bytes freed per leading-dimension row.
        row_bytes = tensor[0].numel() * tensor.element_size()
        if row_bytes == 0:
            continue
        rows_needed = int(-(-excess_size // row_bytes))  # ceil division
        rows_to_drop = min(rows_needed, tensor.size(0) - 1)  # always keep >= 1 row
        merged[key] = tensor[: tensor.size(0) - rows_to_drop].contiguous()
        excess_size -= rows_to_drop * row_bytes
|
75 |
+
|
76 |
def cleanup_files(*file_paths):
|
77 |
for file_path in file_paths:
|
78 |
if os.path.exists(file_path):
|
|
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
try:
|
84 |
+
model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
|
85 |
+
model2_path = "output_checkpoint.safetensors"
|
86 |
+
blend_ratio = 0.75 # Adjust ratio based on requirement
|
87 |
+
output_file = "output_checkpoints.safetensors"
|
|
|
|
|
|
|
|
|
88 |
|
89 |
# Loading models
|
90 |
model1 = load_model(model1_path)
|