Spaces:
Runtime error
Runtime error
Better visualization.
Browse files
app.py
CHANGED
@@ -38,7 +38,7 @@ def find_choices(state_dict):
|
|
38 |
|
39 |
svdvals = torch.linalg.svdvals(state_dict[name])
|
40 |
svdvals /= svdvals.sum()
|
41 |
-
for rank, val in enumerate(svdvals.tolist()[:
|
42 |
data.append((name, layer, group_name, rank, val))
|
43 |
data = np.array(data)
|
44 |
DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
|
@@ -63,8 +63,8 @@ def weights_fn(model_id):
|
|
63 |
with gr.Blocks() as scatter_plot:
|
64 |
with gr.Row():
|
65 |
with gr.Column():
|
66 |
-
model_id = gr.Textbox(
|
67 |
-
weights = gr.Dropdown(
|
68 |
with gr.Column():
|
69 |
plot = gr.LinePlot(show_label=False).style(container=True)
|
70 |
model_id.change(weights_fn, inputs=model_id, outputs=weights)
|
|
|
38 |
|
39 |
svdvals = torch.linalg.svdvals(state_dict[name])
|
40 |
svdvals /= svdvals.sum()
|
41 |
+
for rank, val in enumerate(svdvals.tolist()[:300]):
|
42 |
data.append((name, layer, group_name, rank, val))
|
43 |
data = np.array(data)
|
44 |
DATA = pd.DataFrame(data, columns=["name", "layer", "group_name", "rank", "val"])
|
|
|
63 |
with gr.Blocks() as scatter_plot:
|
64 |
with gr.Row():
|
65 |
with gr.Column():
|
66 |
+
model_id = gr.Textbox(label="model_id")
|
67 |
+
weights = gr.Dropdown(label="weights")
|
68 |
with gr.Column():
|
69 |
plot = gr.LinePlot(show_label=False).style(container=True)
|
70 |
model_id.change(weights_fn, inputs=model_id, outputs=weights)
|