pranavajay commited on
Commit
95c7f67
·
verified ·
1 Parent(s): 59f60bf

Update A.py

Browse files
Files changed (1) hide show
  1. A.py +32 -1
A.py CHANGED
@@ -1,8 +1,20 @@
 
 
1
  from safetensors.torch import load_file, save_file
2
  import torch
3
  torch.cuda.empty_cache()
4
  import torch.nn.functional as F
5
- from tqdm import tqdm # Ensure tqdm is installed
 
 
 
 
 
 
 
 
 
 
6
 
7
  def load_model(file_path):
8
  return load_file(file_path)
@@ -38,19 +50,38 @@ def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.5):
38
 
39
  return merged
40
 
 
 
 
 
 
 
41
  if __name__ == "__main__":
42
  try:
 
 
 
43
  model1_path = "flux1-dev.safetensors"
44
  model2_path = "brainflux_v10.safetensors"
45
  blend_ratio = 0.4
46
  output_file = "output_checkpoint.safetensors"
47
 
 
 
 
 
 
48
  model1 = load_model(model1_path)
49
  model2 = load_model(model2_path)
50
 
 
51
  merged_model = merge_checkpoints(model1, model2, blend_ratio)
52
 
 
53
  save_model(merged_model, output_file)
 
 
 
54
 
55
  except Exception as e:
56
  print(f"An error occurred: {e}")
 
1
+ import os
2
+ import requests
3
  from safetensors.torch import load_file, save_file
4
  import torch
5
  torch.cuda.empty_cache()
6
  import torch.nn.functional as F
7
+ from tqdm import tqdm
8
+
9
+ def download_file(url, dest_path):
10
+ print(f"Downloading {url} to {dest_path}")
11
+ response = requests.get(url, stream=True)
12
+ if response.status_code == 200:
13
+ with open(dest_path, 'wb') as f:
14
+ for chunk in response.iter_content(1024):
15
+ f.write(chunk)
16
+ else:
17
+ raise Exception(f"Failed to download file from {url}")
18
 
19
  def load_model(file_path):
20
  return load_file(file_path)
 
50
 
51
  return merged
52
 
53
+ def cleanup_files(*file_paths):
54
+ for file_path in file_paths:
55
+ if os.path.exists(file_path):
56
+ os.remove(file_path)
57
+ print(f"Deleted {file_path}")
58
+
59
  if __name__ == "__main__":
60
  try:
61
+ model1_url = "https://huggingface.co/multimodalart/FLUX.1-dev2pro-full/resolve/main/flux1-dev.safetensors"
62
+ model2_url = "https://huggingface.co/datasets/John6666/flux1-backup-202409/resolve/main/brainflux_v10.safetensors"
63
+
64
  model1_path = "flux1-dev.safetensors"
65
  model2_path = "brainflux_v10.safetensors"
66
  blend_ratio = 0.4
67
  output_file = "output_checkpoint.safetensors"
68
 
69
+ # Downloading files
70
+ download_file(model1_url, model1_path)
71
+ download_file(model2_url, model2_path)
72
+
73
+ # Loading models
74
  model1 = load_model(model1_path)
75
  model2 = load_model(model2_path)
76
 
77
+ # Merging models
78
  merged_model = merge_checkpoints(model1, model2, blend_ratio)
79
 
80
+ # Saving merged model
81
  save_model(merged_model, output_file)
82
+
83
+ # Cleaning up downloaded files
84
+ cleanup_files(model1_path, model2_path)
85
 
86
  except Exception as e:
87
  print(f"An error occurred: {e}")