pranavajay commited on
Commit
734cd8a
·
verified ·
1 Parent(s): 546e795

Update A.py

Browse files
Files changed (1) hide show
  1. A.py +28 -9
A.py CHANGED
@@ -27,13 +27,14 @@ def resize_tensor_shapes(tensor1, tensor2):
27
  if tensor1.size() == tensor2.size():
28
  return tensor1, tensor2
29
 
 
30
  max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]
31
  tensor1_resized = F.pad(tensor1, (0, max_shape[-1] - tensor1.size(-1)))
32
  tensor2_resized = F.pad(tensor2, (0, max_shape[-1] - tensor2.size(-1)))
33
 
34
  return tensor1_resized, tensor2_resized
35
 
36
- def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
37
  print(f"Merging checkpoints with blend ratio: {blend_ratio}")
38
  merged = {}
39
  all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))
@@ -48,8 +49,30 @@ def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
48
  else:
49
  merged[key] = t2
50
 
 
 
 
51
  return merged
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def cleanup_files(*file_paths):
54
  for file_path in file_paths:
55
  if os.path.exists(file_path):
@@ -58,14 +81,10 @@ def cleanup_files(*file_paths):
58
 
59
  if __name__ == "__main__":
60
  try:
61
-
62
- model1_path = "flux1-dev.safetensors"
63
- model2_path = "brainflux_v10.safetensors"
64
- blend_ratio = 0.4
65
- output_file = "output_checkpoint.safetensors"
66
-
67
- # Downloading files
68
-
69
 
70
  # Loading models
71
  model1 = load_model(model1_path)
 
27
  if tensor1.size() == tensor2.size():
28
  return tensor1, tensor2
29
 
30
+ # Resize tensor2 to match tensor1's size (Base size)
31
  max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]
32
  tensor1_resized = F.pad(tensor1, (0, max_shape[-1] - tensor1.size(-1)))
33
  tensor2_resized = F.pad(tensor2, (0, max_shape[-1] - tensor2.size(-1)))
34
 
35
  return tensor1_resized, tensor2_resized
36
 
37
+ def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.75):
38
  print(f"Merging checkpoints with blend ratio: {blend_ratio}")
39
  merged = {}
40
  all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))
 
49
  else:
50
  merged[key] = t2
51
 
52
+ # Control the final size to be approximately 26 GB
53
+ control_output_size(merged, target_size_gb=26)
54
+
55
  return merged
56
 
57
+ def control_output_size(merged, target_size_gb):
58
+ # Estimate the size in bytes
59
+ target_size_bytes = target_size_gb * 1024**3 # Convert GB to bytes
60
+ current_size_bytes = sum(tensor.numel() * tensor.element_size() for tensor in merged.values())
61
+
62
+ # If the current size exceeds the target, truncate the tensors
63
+ if current_size_bytes > target_size_bytes:
64
+ excess_size = current_size_bytes - target_size_bytes
65
+ print(f"Current size exceeds target by {excess_size / (1024**2):.2f} MB. Adjusting...")
66
+
67
+ # Adjusting the tensors to meet the target size
68
+ for key in merged.keys():
69
+ tensor = merged[key]
70
+ # Calculate how much we can reduce
71
+ reduce_size = excess_size // tensor.element_size() # Number of elements to reduce
72
+ if tensor.numel() > reduce_size:
73
+ # Truncate the tensor
74
+ merged[key] = tensor.flatten()[:tensor.numel() - reduce_size].view(tensor.shape)
75
+
76
  def cleanup_files(*file_paths):
77
  for file_path in file_paths:
78
  if os.path.exists(file_path):
 
81
 
82
  if __name__ == "__main__":
83
  try:
84
+ model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
85
+ model2_path = "output_checkpoint.safetensors"
86
+ blend_ratio = 0.75 # Adjust ratio based on requirement
87
+ output_file = "output_checkpoints.safetensors"
 
 
 
 
88
 
89
  # Loading models
90
  model1 = load_model(model1_path)