Spaces:
Runtime error
Runtime error
Fixing full spectrum.
Browse files
app.py
CHANGED
@@ -4,6 +4,10 @@ import numpy as np
|
|
4 |
import pandas as pd
|
5 |
import re
|
6 |
import torch
|
|
|
|
|
|
|
|
|
7 |
|
8 |
number_re = re.compile(r"\.[0-9]*\.")
|
9 |
|
@@ -25,10 +29,10 @@ def scatter_plot_fn(group_name):
|
|
25 |
|
26 |
|
27 |
def find_choices(state_dict):
|
|
|
|
|
28 |
global DATA
|
29 |
-
layered_tensors = [
|
30 |
-
k for k, v in state_dict.items() if number_re.findall(k) and len(v.shape) == 2
|
31 |
-
]
|
32 |
choices = set()
|
33 |
data = []
|
34 |
for name in layered_tensors:
|
@@ -38,7 +42,7 @@ def find_choices(state_dict):
|
|
38 |
|
39 |
svdvals = torch.linalg.svdvals(state_dict[name])
|
40 |
svdvals /= svdvals.sum()
|
41 |
-
for rank, val in enumerate(svdvals.tolist()
|
42 |
data.append((name, layer, group_name, rank, val))
|
43 |
data = np.array(data)
|
44 |
DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
|
@@ -71,4 +75,4 @@ with gr.Blocks() as scatter_plot:
|
|
71 |
weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
|
72 |
|
73 |
if __name__ == "__main__":
|
74 |
-
scatter_plot.launch()
|
|
|
4 |
import pandas as pd
|
5 |
import re
|
6 |
import torch
|
7 |
+
import altair as alt
|
8 |
+
|
9 |
+
|
10 |
+
alt.data_transformers.disable_max_rows()
|
11 |
|
12 |
number_re = re.compile(r"\.[0-9]*\.")
|
13 |
|
|
|
29 |
|
30 |
|
31 |
def find_choices(state_dict):
|
32 |
+
if not state_dict:
|
33 |
+
return []
|
34 |
global DATA
|
35 |
+
layered_tensors = [k for k, v in state_dict.items() if number_re.findall(k) and len(v.shape) == 2]
|
|
|
|
|
36 |
choices = set()
|
37 |
data = []
|
38 |
for name in layered_tensors:
|
|
|
42 |
|
43 |
svdvals = torch.linalg.svdvals(state_dict[name])
|
44 |
svdvals /= svdvals.sum()
|
45 |
+
for rank, val in enumerate(svdvals.tolist()):
|
46 |
data.append((name, layer, group_name, rank, val))
|
47 |
data = np.array(data)
|
48 |
DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
|
|
|
75 |
weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)
|
76 |
|
77 |
if __name__ == "__main__":
|
78 |
+
scatter_plot.launch(share=True)
|