pranavajay commited on
Commit
058f52d
·
verified ·
1 Parent(s): aa71077

Update rp.py

Browse files
Files changed (1) hide show
  1. rp.py +53 -24
rp.py CHANGED
@@ -1,26 +1,55 @@
1
- import torch
2
  from safetensors.torch import load_file, save_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- def reduce_key_size(input_file, output_file, reduction_factor=0.50):
5
- # Load the model
6
- model_data = load_file(input_file)
7
-
8
- # Iterate through all the tensors and reduce their size
9
- for key in model_data.keys():
10
- original_tensor = model_data[key]
11
-
12
- # Calculate the new size
13
- new_size = int(original_tensor.size(0) * (1 - reduction_factor))
14
-
15
- # Resize the tensor (this could vary depending on your requirements)
16
- if new_size > 0: # Ensure new size is positive
17
- reduced_tensor = original_tensor[:new_size]
18
- model_data[key] = reduced_tensor
19
-
20
- # Save the modified model
21
- save_file(model_data, output_file)
22
-
23
- # Usage example
24
- input_file = 'model-00002-of-00002.safetensors' # Replace with your input model file
25
- output_file = 'model-00002-of-00002.safetensors' # Desired output file name
26
- reduce_key_size(input_file, output_file)
 
 
1
  from safetensors.torch import load_file, save_file
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from tqdm import tqdm # Ensure tqdm is installed
5
+
6
+ def load_model(file_path):
7
+ return load_file(file_path)
8
+
9
+ def save_model(merged_model, output_file):
10
+ print(f"Saving merged model to {output_file}")
11
+ save_file(merged_model, output_file)
12
+
13
+ def resize_tensor_shapes(tensor1, tensor2):
14
+ if tensor1.size() == tensor2.size():
15
+ return tensor1, tensor2
16
+
17
+ max_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]
18
+ tensor1_resized = F.pad(tensor1, (0, max_shape[-1] - tensor1.size(-1)))
19
+ tensor2_resized = F.pad(tensor2, (0, max_shape[-1] - tensor2.size(-1)))
20
+
21
+ return tensor1_resized, tensor2_resized
22
+
23
+ def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
24
+ print(f"Merging checkpoints with blend ratio: {blend_ratio}")
25
+ merged = {}
26
+ all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))
27
+
28
+ for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
29
+ t1, t2 = ckpt1.get(key), ckpt2.get(key)
30
+ if t1 is not None and t2 is not None:
31
+ t1, t2 = resize_tensor_shapes(t1, t2)
32
+ merged[key] = blend_ratio * t1 + (1 - blend_ratio) * t2
33
+ elif t1 is not None:
34
+ merged[key] = t1
35
+ else:
36
+ merged[key] = t2
37
+
38
+ return merged
39
+
40
+ if __name__ == "__main__":
41
+ # Set your file paths and blend ratio here
42
+ model1_path = "flux1-dev.safetensors.1" # Model 1 path
43
+ model2_path = "brainflux_v10.safetensors" # Model 2 path
44
+ blend_ratio = 0.4 # Blend ratio
45
+ output_file = "output_checkpoint.safetensors" # Output file name
46
+
47
+ # Load the models
48
+ model1 = load_model(model1_path)
49
+ model2 = load_model(model2_path)
50
+
51
+ # Merge the models
52
+ merged_model = merge_checkpoints(model1, model2, blend_ratio)
53
 
54
+ # Save the merged model
55
+ save_model(merged_model, output_file)