Diffusers
English
sayakpaul HF staff committed on
Commit
68307ce
·
verified ·
1 Parent(s): 4233ccc

Upload folder using huggingface_hub

Browse files
Files changed (39) hide show
  1. .gitattributes +20 -0
  2. How2Draw-V2_000002800_rand_svd.safetensors +3 -0
  3. How2Draw-V2_000002800_reduced.safetensors +3 -0
  4. How2Draw-V2_000002800_reduced_sparse.safetensors +3 -0
  5. How2Draw-V2_000002800_svd.safetensors +3 -0
  6. images/How2Draw-V2_000002800_rand_svd_0.png +3 -0
  7. images/How2Draw-V2_000002800_rand_svd_1.png +0 -0
  8. images/How2Draw-V2_000002800_rand_svd_2.png +3 -0
  9. images/How2Draw-V2_000002800_rand_svd_3.png +0 -0
  10. images/How2Draw-V2_000002800_rand_svd_collage_0.png +3 -0
  11. images/How2Draw-V2_000002800_rand_svd_collage_1.png +3 -0
  12. images/How2Draw-V2_000002800_rand_svd_collage_2.png +3 -0
  13. images/How2Draw-V2_000002800_rand_svd_collage_3.png +3 -0
  14. images/How2Draw-V2_000002800_svd_0.png +3 -0
  15. images/How2Draw-V2_000002800_svd_1.png +0 -0
  16. images/How2Draw-V2_000002800_svd_2.png +3 -0
  17. images/How2Draw-V2_000002800_svd_3.png +0 -0
  18. images/How2Draw-V2_000002800_svd_collage_0.png +3 -0
  19. images/How2Draw-V2_000002800_svd_collage_1.png +3 -0
  20. images/How2Draw-V2_000002800_svd_collage_2.png +3 -0
  21. images/How2Draw-V2_000002800_svd_collage_3.png +3 -0
  22. images/collage_0.png +3 -0
  23. images/collage_1.png +3 -0
  24. images/collage_2.png +3 -0
  25. images/collage_3.png +3 -0
  26. images/original_0.png +3 -0
  27. images/original_1.png +0 -0
  28. images/original_2.png +3 -0
  29. images/original_3.png +0 -0
  30. images/reduced_0.png +3 -0
  31. images/reduced_1.png +0 -0
  32. images/reduced_2.png +3 -0
  33. images/reduced_3.png +0 -0
  34. images/reduced_sparse_0.png +0 -0
  35. images/reduced_sparse_1.png +0 -0
  36. images/reduced_sparse_2.png +0 -0
  37. images/reduced_sparse_3.png +0 -0
  38. low_rank_lora.py +156 -0
  39. svd_low_rank_lora.py +178 -0
.gitattributes CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/How2Draw-V2_000002800_rand_svd_0.png filter=lfs diff=lfs merge=lfs -text
37
+ images/How2Draw-V2_000002800_rand_svd_2.png filter=lfs diff=lfs merge=lfs -text
38
+ images/How2Draw-V2_000002800_rand_svd_collage_0.png filter=lfs diff=lfs merge=lfs -text
39
+ images/How2Draw-V2_000002800_rand_svd_collage_1.png filter=lfs diff=lfs merge=lfs -text
40
+ images/How2Draw-V2_000002800_rand_svd_collage_2.png filter=lfs diff=lfs merge=lfs -text
41
+ images/How2Draw-V2_000002800_rand_svd_collage_3.png filter=lfs diff=lfs merge=lfs -text
42
+ images/How2Draw-V2_000002800_svd_0.png filter=lfs diff=lfs merge=lfs -text
43
+ images/How2Draw-V2_000002800_svd_2.png filter=lfs diff=lfs merge=lfs -text
44
+ images/How2Draw-V2_000002800_svd_collage_0.png filter=lfs diff=lfs merge=lfs -text
45
+ images/How2Draw-V2_000002800_svd_collage_1.png filter=lfs diff=lfs merge=lfs -text
46
+ images/How2Draw-V2_000002800_svd_collage_2.png filter=lfs diff=lfs merge=lfs -text
47
+ images/How2Draw-V2_000002800_svd_collage_3.png filter=lfs diff=lfs merge=lfs -text
48
+ images/collage_0.png filter=lfs diff=lfs merge=lfs -text
49
+ images/collage_1.png filter=lfs diff=lfs merge=lfs -text
50
+ images/collage_2.png filter=lfs diff=lfs merge=lfs -text
51
+ images/collage_3.png filter=lfs diff=lfs merge=lfs -text
52
+ images/original_0.png filter=lfs diff=lfs merge=lfs -text
53
+ images/original_2.png filter=lfs diff=lfs merge=lfs -text
54
+ images/reduced_0.png filter=lfs diff=lfs merge=lfs -text
55
+ images/reduced_2.png filter=lfs diff=lfs merge=lfs -text
How2Draw-V2_000002800_rand_svd.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5931bbc8fd3c854596c5f56660649de6a2e9a220cc10cf67f1d4d0c58b6dbd5
3
+ size 43090208
How2Draw-V2_000002800_reduced.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:143d9e2b2b288a25086d034f8f750831dcd4e80c0a2814107bd36d3c80c8f999
3
+ size 43090208
How2Draw-V2_000002800_reduced_sparse.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2081b728f126065155dd11b4d60384967fe839c8a35088f781742258399ee416
3
+ size 43090208
How2Draw-V2_000002800_svd.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f4397c7fba89f7805a7ba30b99034538da4a4369292b64d9e72a3bc3cdb8f3f
3
+ size 43090208
images/How2Draw-V2_000002800_rand_svd_0.png ADDED

Git LFS Details

  • SHA256: d9016fdfd3b4bbcdef223fd28c06ff68aeb3d9a34d195a79570e5981d24907a5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
images/How2Draw-V2_000002800_rand_svd_1.png ADDED
images/How2Draw-V2_000002800_rand_svd_2.png ADDED

Git LFS Details

  • SHA256: ca82168e3c4f8a88126080aaaf4f4f3d7117d2946726cda47cbe8ff4b2e4f0a5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
images/How2Draw-V2_000002800_rand_svd_3.png ADDED
images/How2Draw-V2_000002800_rand_svd_collage_0.png ADDED

Git LFS Details

  • SHA256: 90db872203ac311a835adb29449b4b9fffab00b029b198f4ef3bb229132f6ba9
  • Pointer size: 132 Bytes
  • Size of remote file: 2.34 MB
images/How2Draw-V2_000002800_rand_svd_collage_1.png ADDED

Git LFS Details

  • SHA256: 3d01743c085e2bdcf3fccbcc026fa0b3b7525f0682bcb8e45a7752d15741e571
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
images/How2Draw-V2_000002800_rand_svd_collage_2.png ADDED

Git LFS Details

  • SHA256: a1b7a5336bab67aed13015d6acd3f82c0c4462dc54d496f8725fd1f78966e5bd
  • Pointer size: 132 Bytes
  • Size of remote file: 2.17 MB
images/How2Draw-V2_000002800_rand_svd_collage_3.png ADDED

Git LFS Details

  • SHA256: d01080ba48ed8145a01d173ad948f35b419b72cd9a8630d346412d04e5ceb0db
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
images/How2Draw-V2_000002800_svd_0.png ADDED

Git LFS Details

  • SHA256: 23da01d39ed50d2fdabc06684a61fad0cd249c0a42545b5846deda3da58eae95
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
images/How2Draw-V2_000002800_svd_1.png ADDED
images/How2Draw-V2_000002800_svd_2.png ADDED

Git LFS Details

  • SHA256: e210857072b7d45f55540553b88726dbdbad1027ebe3b2249a8bd2373e08c7ae
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
images/How2Draw-V2_000002800_svd_3.png ADDED
images/How2Draw-V2_000002800_svd_collage_0.png ADDED

Git LFS Details

  • SHA256: ce45e9f7c076ea55b412e4d837fdf32cc69a706e7024091b09ca590d55d49b8d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.34 MB
images/How2Draw-V2_000002800_svd_collage_1.png ADDED

Git LFS Details

  • SHA256: 1fbea3b6665c94bb148c5a1b82c18f004a2c9c591b9dc3bde2cd375edf897412
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
images/How2Draw-V2_000002800_svd_collage_2.png ADDED

Git LFS Details

  • SHA256: 7783a4a054dd395e57c059cc28b56ae611b838fa86a51b3e695f2d3c7c85ab84
  • Pointer size: 132 Bytes
  • Size of remote file: 2.17 MB
images/How2Draw-V2_000002800_svd_collage_3.png ADDED

Git LFS Details

  • SHA256: 059a8c091642ccd04d50befa13a005421d79d7bd00cdbcc175388169152b1220
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
images/collage_0.png ADDED

Git LFS Details

  • SHA256: 8c57ce49491954dc8aa01d8f2cebb2d6679a4c1e0dc8e0506a2fd54773182f6f
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
images/collage_1.png ADDED

Git LFS Details

  • SHA256: d1c4757013129f937ef029bb476313854512632dcac3c6dc9d74584184af4ddd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
images/collage_2.png ADDED

Git LFS Details

  • SHA256: a318be854280bb1de9b2fd4071a65698b1c615ebee63cb9b560759295cc43d13
  • Pointer size: 132 Bytes
  • Size of remote file: 2.02 MB
images/collage_3.png ADDED

Git LFS Details

  • SHA256: b05463e06e7fa8b0d257ea7f082b6e48dee66e181e61381be466f879e4090d53
  • Pointer size: 132 Bytes
  • Size of remote file: 1.31 MB
images/original_0.png ADDED

Git LFS Details

  • SHA256: f7a38ef30451b255916070bb9a894c223268cf9cd7963d73df3e6f6e73dde841
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
images/original_1.png ADDED
images/original_2.png ADDED

Git LFS Details

  • SHA256: e9afd5a5bca8a4cdbdb6a18cda3541110c59b65adf21d3c9a0dfd77475c92414
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
images/original_3.png ADDED
images/reduced_0.png ADDED

Git LFS Details

  • SHA256: 1c20eadb0f8d5da37e440edfce1e849e4aa9659b6704b7c8a9c81e2bcfcff763
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
images/reduced_1.png ADDED
images/reduced_2.png ADDED

Git LFS Details

  • SHA256: edca94259042c80947d6546ae215ebc3a7b3ebf0d98703bfee3baf6c67244580
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB
images/reduced_3.png ADDED
images/reduced_sparse_0.png ADDED
images/reduced_sparse_1.png ADDED
images/reduced_sparse_2.png ADDED
images/reduced_sparse_3.png ADDED
low_rank_lora.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+
4
+ python low_rank_lora.py --repo_id=glif/how2draw --filename="How2Draw-V2_000002800.safetensors" \
5
+ --new_rank=4 --new_lora_path="How2Draw-V2_000002800_rank_4.safetensors"
6
+ """
7
+
8
+ import torch
9
+ from huggingface_hub import hf_hub_download
10
+ import safetensors.torch
11
+ import fire
12
+
13
+
14
def sparse_random_projection_matrix(original_rank, new_rank, density=0.1):
    """
    Generates a sparse random projection matrix.

    Every row gets the same number of standard-normal entries placed at randomly
    chosen column positions; all remaining entries are zero. The matrix is then
    scaled by 1 / sqrt(new_rank).

    Args:
        original_rank (int): Original rank (number of columns of the result).
        new_rank (int): Reduced rank (number of rows of the result).
        density (float): Fraction of non-zero elements in each row.

    Returns:
        R (torch.Tensor): Sparse random projection matrix of shape [new_rank, original_rank].
    """
    R = torch.zeros(new_rank, original_rank)
    # Clamp the per-row count to [1, original_rank]: a count of 0 (small ranks with
    # the default density) would yield an all-zero projection, and a count above
    # original_rank would not match the number of available column indices.
    num_nonzero = min(original_rank, max(1, int(density * original_rank)))
    for i in range(new_rank):
        indices = torch.randperm(original_rank)[:num_nonzero]
        values = torch.randn(num_nonzero)
        R[i, indices] = values
    return R / torch.sqrt(torch.tensor(new_rank, dtype=torch.float32))
33
+
34
+
35
def reduce_lora_rank_random_projection(lora_A, lora_B, new_rank=4, use_sparse=False):
    """
    Reduces the rank of LoRA matrices lora_A and lora_B using random projections.

    A projection matrix R of shape [new_rank, original_rank] is drawn and applied as
    lora_A_new = R @ lora_A and lora_B_new = lora_B @ R.T.

    Args:
        lora_A (torch.Tensor): Original lora_A matrix of shape [original_rank, in_features].
        lora_B (torch.Tensor): Original lora_B matrix of shape [out_features, original_rank].
        new_rank (int): Desired lower rank.
        use_sparse (bool): Use sparse projection matrix.

    Returns:
        lora_A_new (torch.Tensor): Reduced lora_A matrix of shape [new_rank, in_features].
        lora_B_new (torch.Tensor): Reduced lora_B matrix of shape [out_features, new_rank].
    """
    original_rank = lora_A.shape[0]  # Assuming lora_A.shape = [original_rank, in_features]

    # Generate random projection matrix
    if use_sparse:
        R = sparse_random_projection_matrix(original_rank=original_rank, new_rank=new_rank)
    else:
        R = torch.randn(new_rank, original_rank, dtype=torch.float32) / torch.sqrt(
            torch.tensor(new_rank, dtype=torch.float32)
        )
    # Keep R in float32 so the projections below are computed in full precision
    # even when the LoRA weights are stored in half precision; only the final
    # results are cast back to the input dtypes.
    R = R.to(lora_A.device, torch.float32)

    # Project lora_A and lora_B
    lora_A_new = (R @ lora_A.to(torch.float32)).to(lora_A.dtype)  # Shape: [new_rank, in_features]
    lora_B_new = (lora_B.to(torch.float32) @ R.T).to(lora_B.dtype)  # Shape: [out_features, new_rank]

    return lora_A_new, lora_B_new
65
+
66
+
67
def reduce_lora_rank_state_dict_random_projection(state_dict, new_rank=4, use_sparse=False):
    """
    Reduces the rank of all LoRA matrices in the given state dict using random projections.

    Only keys containing "lora_A.weight" that have a matching "lora_B.weight" key
    are processed; every other entry is carried over unchanged.

    Args:
        state_dict (dict): The state dict containing LoRA matrices.
        new_rank (int): Desired lower rank.
        use_sparse (bool): Use sparse projection matrix.

    Returns:
        new_state_dict (dict): State dict with reduced-rank LoRA matrices. The reduced
            tensors live on the compute device (CUDA when available, otherwise CPU).
    """
    # Use the GPU when one is present, but do not crash on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    new_state_dict = state_dict.copy()
    for lora_A_key in list(state_dict.keys()):
        if "lora_A.weight" not in lora_A_key:
            continue

        # Find the corresponding lora_B
        lora_B_key = lora_A_key.replace("lora_A.weight", "lora_B.weight")
        if lora_B_key not in state_dict:
            continue

        # Move both factors to the compute device for the projection
        lora_A = state_dict[lora_A_key].to(device)
        lora_B = state_dict[lora_B_key].to(device)

        # Apply the rank reduction using random projection
        lora_A_new, lora_B_new = reduce_lora_rank_random_projection(
            lora_A, lora_B, new_rank=new_rank, use_sparse=use_sparse
        )

        # Update the state dict
        new_state_dict[lora_A_key] = lora_A_new
        new_state_dict[lora_B_key] = lora_B_new

        print(f"Reduced rank of {lora_A_key} and {lora_B_key} to {new_rank}")

    return new_state_dict
106
+
107
+
108
def compare_approximation_error(orig_state_dict, new_state_dict):
    """
    Prints the relative Frobenius-norm error of each reduced LoRA update.

    For every "lora_A.weight" key, the update lora_B @ lora_A is formed from both
    state dicts and ||old - new||_F / ||old||_F is printed. Pairs that are missing
    a lora_A or lora_B entry in either state dict are skipped.

    Args:
        orig_state_dict (dict): State dict holding the original LoRA matrices.
        new_state_dict (dict): State dict holding the reduced-rank LoRA matrices.
    """
    for key in orig_state_dict:
        if "lora_A.weight" not in key:
            continue

        lora_A_key = key
        lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
        # The rank reduction leaves unpaired lora_A entries untouched, so there is
        # nothing to compare for them.
        if (
            lora_B_key not in orig_state_dict
            or lora_A_key not in new_state_dict
            or lora_B_key not in new_state_dict
        ):
            continue

        lora_A_new = new_state_dict[lora_A_key]
        lora_B_new = new_state_dict[lora_B_key]
        # Compare on whatever device the reduced tensors live on, in float32, so the
        # check works without a GPU and is not dominated by half-precision rounding.
        device = lora_A_new.device

        # Original delta_W
        lora_A_old = orig_state_dict[lora_A_key].to(device, torch.float32)
        lora_B_old = orig_state_dict[lora_B_key].to(device, torch.float32)
        delta_W_old = lora_B_old @ lora_A_old

        # Approximated delta_W
        delta_W_new = lora_B_new.to(device, torch.float32) @ lora_A_new.to(torch.float32)

        # Compute the approximation error
        error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
        print(f"Relative error for {lora_A_key}: {error.item():.6f}")
127
+
128
+
129
def main(
    repo_id: str,
    filename: str,
    new_rank: int,
    use_sparse: bool = False,
    check_error: bool = False,
    new_lora_path: str = None,
):
    """
    Downloads a LoRA checkpoint, reduces its rank with random projections and saves it.

    Args:
        repo_id (str): Hub repository that hosts the LoRA checkpoint.
        filename (str): Name of the safetensors file inside the repository.
        new_rank (int): Desired lower rank.
        use_sparse (bool): Use a sparse projection matrix instead of a dense one.
        check_error (bool): Print the relative approximation error per LoRA pair.
        new_lora_path (str): Path the reduced state dict is written to.

    Raises:
        ValueError: If `new_lora_path` is not provided.
    """
    if new_lora_path is None:
        raise ValueError("Please provide a path to serialize the converted state dict.")

    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    original_state_dict = safetensors.torch.load_file(ckpt_path)
    new_state_dict = reduce_lora_rank_state_dict_random_projection(
        original_state_dict, new_rank=new_rank, use_sparse=use_sparse
    )

    if check_error:
        compare_approximation_error(original_state_dict, new_state_dict)

    # safetensors can only serialize contiguous tensors.
    new_state_dict = {k: v.to("cpu").contiguous() for k, v in new_state_dict.items()}
    # `save_file` writes to disk; `safetensors.torch.save` only returns the serialized
    # bytes and would treat the path as its metadata argument.
    safetensors.torch.save_file(new_state_dict, new_lora_path)


if __name__ == "__main__":
    fire.Fire(main)
svd_low_rank_lora.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+
4
+ Regular SVD:
5
+ python svd_low_rank_lora.py --repo_id=glif/how2draw --filename="How2Draw-V2_000002800.safetensors" \
6
+ --new_rank=4 --new_lora_path="How2Draw-V2_000002800_svd.safetensors"
7
+
8
+ Randomized SVD:
9
+ python svd_low_rank_lora.py --repo_id=glif/how2draw --filename="How2Draw-V2_000002800.safetensors" \
10
+ --new_rank=4 --niter=5 --new_lora_path="How2Draw-V2_000002800_rand_svd.safetensors"
11
+ """
12
+
13
+ import torch
14
+ from huggingface_hub import hf_hub_download
15
+ import safetensors.torch
16
+ import fire
17
+
18
+
19
def randomized_svd(matrix, rank, niter=5):
    """
    Performs a randomized SVD on the given matrix.

    Args:
        matrix (torch.Tensor): The input 2D matrix.
        rank (int): The target rank.
        niter (int): Number of iterations for power method.

    Returns:
        U (torch.Tensor), S (torch.Tensor), Vh (torch.Tensor): Factors truncated to
        at most `rank` components.
    """
    # Step 1: Generate a random Gaussian matrix. It must share the input's dtype
    # (not just its device), otherwise the matmul below fails for non-float32 inputs.
    omega = torch.randn(matrix.size(1), rank, device=matrix.device, dtype=matrix.dtype)

    # Step 2: Form Y = A * Omega
    Y = matrix @ omega

    # Step 3: Orthonormalize Y using QR decomposition
    Q, _ = torch.linalg.qr(Y, mode="reduced")

    # Power iteration (optional, improves approximation)
    for _ in range(niter):
        Z = matrix.T @ Q
        Q, _ = torch.linalg.qr(matrix @ Z, mode="reduced")

    # Step 4: Compute B = Q^T * A
    B = Q.T @ matrix

    # Step 5: Compute SVD of the small matrix B
    Ub, S, Vh = torch.linalg.svd(B, full_matrices=False)

    # Step 6: Compute U = Q * Ub
    U = Q @ Ub

    return U[:, :rank], S[:rank], Vh[:rank, :]
53
+
54
+
55
def reduce_lora_rank(lora_A, lora_B, niter, new_rank=4):
    """
    Reduces the rank of LoRA matrices lora_A and lora_B with SVD, supporting randomized SVD, too.

    The update lora_B @ lora_A is decomposed, the top `new_rank` singular triplets
    are kept, and the square roots of the singular values are split evenly between
    the two new factors.

    Args:
        lora_A (torch.Tensor): Original lora_A matrix of shape [original_rank, in_features].
        lora_B (torch.Tensor): Original lora_B matrix of shape [out_features, original_rank].
        niter (int): Number of power iterations for randomized SVD. `None` or 0
            selects the regular (full) SVD instead.
        new_rank (int): Desired lower rank.

    Returns:
        lora_A_new (torch.Tensor): Reduced lora_A matrix of shape [new_rank, in_features].
        lora_B_new (torch.Tensor): Reduced lora_B matrix of shape [out_features, new_rank].
    """
    # Compute the low-rank update matrix in float32, on the GPU when one is available.
    dtype = lora_A.dtype
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lora_A = lora_A.to(device, torch.float32)
    lora_B = lora_B.to(device, torch.float32)
    delta_W = lora_B @ lora_A

    # A single if/else guarantees U, S and Vh are always bound, including for niter=0.
    if niter:
        # Perform randomized SVD
        U, S, Vh = randomized_svd(delta_W, rank=new_rank, niter=niter)
    else:
        # Perform SVD on the update matrix
        U, S, Vh = torch.linalg.svd(delta_W, full_matrices=False)

    # Keep only the top 'new_rank' singular values and vectors
    U_new = U[:, :new_rank]
    S_new = S[:new_rank]
    Vh_new = Vh[:new_rank, :]

    # Compute the square roots of the singular values
    S_sqrt = torch.sqrt(S_new)

    # Compute the new lora_B and lora_A matrices
    lora_B_new = U_new * S_sqrt.unsqueeze(0)  # Shape: [out_features, new_rank]
    lora_A_new = S_sqrt.unsqueeze(1) * Vh_new  # Shape: [new_rank, in_features]

    return lora_A_new.to(dtype), lora_B_new.to(dtype)
95
+
96
+
97
def reduce_lora_rank_state_dict(state_dict, niter, new_rank=4):
    """
    Reduces the rank of all LoRA matrices in the given state dict.

    Entries whose key contains "lora_A.weight" and that have a matching
    "lora_B.weight" entry are replaced by their reduced-rank versions; all other
    entries are copied over as they are.

    Args:
        state_dict (dict): The state dict containing LoRA matrices.
        niter (int): Number of power iterations for randomized SVD.
        new_rank (int): Desired lower rank.

    Returns:
        new_state_dict (dict): State dict with reduced-rank LoRA matrices.
    """
    reduced = state_dict.copy()
    for a_name in list(state_dict.keys()):
        if "lora_A.weight" not in a_name:
            continue

        # Locate the partner lora_B entry; unpaired lora_A entries are left alone
        b_name = a_name.replace("lora_A.weight", "lora_B.weight")
        if b_name not in state_dict:
            continue

        # Reduce the pair and store both new factors
        a_reduced, b_reduced = reduce_lora_rank(
            state_dict[a_name], state_dict[b_name], niter=niter, new_rank=new_rank
        )
        reduced[a_name] = a_reduced
        reduced[b_name] = b_reduced

        print(f"Reduced rank of {a_name} and {b_name} to {new_rank}")

    return reduced
130
+
131
+
132
def compare_approximation_error(orig_state_dict, new_state_dict):
    """
    Prints the relative Frobenius-norm error of each reduced LoRA update.

    For every "lora_A.weight" key, the update lora_B @ lora_A is formed from both
    state dicts and ||old - new||_F / ||old||_F is printed. Pairs that are missing
    a lora_A or lora_B entry in either state dict are skipped.

    Args:
        orig_state_dict (dict): State dict holding the original LoRA matrices.
        new_state_dict (dict): State dict holding the reduced-rank LoRA matrices.
    """
    for key in orig_state_dict:
        if "lora_A.weight" not in key:
            continue

        lora_A_key = key
        lora_B_key = key.replace("lora_A.weight", "lora_B.weight")
        # The rank reduction leaves unpaired lora_A entries untouched, so there is
        # nothing to compare for them.
        if (
            lora_B_key not in orig_state_dict
            or lora_A_key not in new_state_dict
            or lora_B_key not in new_state_dict
        ):
            continue

        lora_A_new = new_state_dict[lora_A_key]
        lora_B_new = new_state_dict[lora_B_key]
        # Compare on whatever device the reduced tensors live on, in float32, so the
        # check works without a GPU and is not dominated by half-precision rounding.
        device = lora_A_new.device

        # Original delta_W
        lora_A_old = orig_state_dict[lora_A_key].to(device, torch.float32)
        lora_B_old = orig_state_dict[lora_B_key].to(device, torch.float32)
        delta_W_old = lora_B_old @ lora_A_old

        # Approximated delta_W
        delta_W_new = lora_B_new.to(device, torch.float32) @ lora_A_new.to(torch.float32)

        # Compute the approximation error
        error = torch.norm(delta_W_old - delta_W_new, p="fro") / torch.norm(delta_W_old, p="fro")
        print(f"Relative error for {lora_A_key}: {error.item():.6f}")
151
+
152
+
153
def main(
    repo_id: str,
    filename: str,
    new_rank: int,
    niter: int = None,
    check_error: bool = False,
    new_lora_path: str = None,
):
    """
    Downloads a LoRA checkpoint, reduces its rank with SVD and saves the result.

    Args:
        repo_id (str): Hub repository that hosts the LoRA checkpoint.
        filename (str): Name of the safetensors file inside the repository.
        new_rank (int): Desired lower rank.
        niter (int): Number of power iterations for randomized SVD; leave unset
            for the regular SVD.
        check_error (bool): Print the relative approximation error per LoRA pair.
        new_lora_path (str): Path the reduced state dict is written to.

    Raises:
        ValueError: If `new_lora_path` is not provided.
    """
    # Fail before any download happens when there is nowhere to write the result.
    if new_lora_path is None:
        raise ValueError("Please provide a path to serialize the converted state dict.")

    source_weights = safetensors.torch.load_file(hf_hub_download(repo_id=repo_id, filename=filename))
    reduced_weights = reduce_lora_rank_state_dict(source_weights, niter=niter, new_rank=new_rank)

    if check_error:
        compare_approximation_error(source_weights, reduced_weights)

    # Bring every tensor back to CPU in contiguous layout before serializing.
    serializable = {}
    for name, tensor in reduced_weights.items():
        serializable[name] = tensor.to("cpu").contiguous()
    safetensors.torch.save_file(serializable, new_lora_path)


if __name__ == "__main__":
    fire.Fire(main)