Update rp.py
Browse files
rp.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from safetensors.torch import load_file, save_file
|
2 |
import torch
|
|
|
3 |
import torch.nn.functional as F
|
4 |
from tqdm import tqdm # Ensure tqdm is installed
|
5 |
|
@@ -38,18 +39,18 @@ def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
|
|
38 |
return merged
|
39 |
|
40 |
if __name__ == "__main__":
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
model2 = load_model(model2_path)
|
50 |
|
51 |
-
|
52 |
-
merged_model = merge_checkpoints(model1, model2, blend_ratio)
|
53 |
|
54 |
-
|
55 |
-
|
|
|
|
|
|
1 |
from safetensors.torch import load_file, save_file
|
2 |
import torch
|
3 |
+
torch.cuda.empty_cache()
|
4 |
import torch.nn.functional as F
|
5 |
from tqdm import tqdm # Ensure tqdm is installed
|
6 |
|
|
|
39 |
return merged
|
40 |
|
41 |
if __name__ == "__main__":
|
42 |
+
try:
|
43 |
+
model1_path = "flux1-dev.safetensors.1"
|
44 |
+
model2_path = "brainflux_v10.safetensors"
|
45 |
+
blend_ratio = 0.4
|
46 |
+
output_file = "output_checkpoint.safetensors"
|
47 |
|
48 |
+
model1 = load_model(model1_path)
|
49 |
+
model2 = load_model(model2_path)
|
|
|
50 |
|
51 |
+
merged_model = merge_checkpoints(model1, model2, blend_ratio)
|
|
|
52 |
|
53 |
+
save_model(merged_model, output_file)
|
54 |
+
|
55 |
+
except Exception as e:
|
56 |
+
print(f"An error occurred: {e}")
|