graph_spectrum / app.py
Narsil's picture
Narsil HF staff
Not share.
4c415fe
raw
history blame
2.18 kB
import gradio as gr
from transformers import pipeline
import numpy as np
import pandas as pd
import re
import torch
import altair as alt
# Lift Altair's row limit for embedded data: the table built by
# ``find_choices`` holds one row per singular value of every matching tensor.
alt.data_transformers.disable_max_rows()

# Matches a purely numeric path segment of a parameter name, including the
# surrounding dots (e.g. ".11." in "h.11.attn.weight"). At least one digit is
# required so that an empty segment ("..") is never mistaken for a layer
# index -- ``find_choices`` feeds the text between the dots to ``int()``.
number_re = re.compile(r"\.[0-9]+\.")

# State dict of the most recently requested model; rebound by ``weights_fn``
# (reset to ``{}`` when loading fails).
STATE_DICT = {}

# Table of normalized singular values; rebound by ``find_choices`` and read
# by ``scatter_plot_fn``.
DATA = pd.DataFrame()
def scatter_plot_fn(group_name):
    """Return the line-plot update for one weight group.

    Keeps the rows of the module-level ``DATA`` table whose ``group_name``
    column equals *group_name* and hands them to ``gr.LinePlot.update``,
    plotting ``val`` against ``rank`` with one colour per ``layer``.
    """
    # DATA is only read here, so the module-level binding is used directly.
    in_group = DATA.group_name == group_name
    plot_options = dict(
        x="rank",
        y="val",
        color="layer",
        tooltip=["val", "rank", "layer"],
        caption="",
    )
    return gr.LinePlot.update(value=DATA[in_group], **plot_options)
def find_choices(state_dict):
    """Compute normalized singular-value spectra of per-layer weight matrices.

    A tensor is used when its key contains a segment matched by the
    module-level ``number_re``, it is 2-D, and its dtype is floating point or
    complex. Keys that differ only in those numeric segments share a group
    name in which every matched segment is replaced by ``.{N}.``.

    Side effect: rebinds the module-level ``DATA`` frame to one row per
    singular value, with columns ``name``, ``layer`` (first numeric segment
    of the key, as an integer category), ``group_name``, ``rank`` (position
    in the order returned by ``torch.linalg.svdvals``) and ``val`` (the
    singular value divided by the sum of that tensor's singular values).
    ``DATA`` is left untouched when *state_dict* is empty.

    Args:
        state_dict: Mapping from parameter name to tensor; ``weights_fn``
            passes ``model.state_dict()``.

    Returns:
        A sorted list of group names; empty when *state_dict* is empty or no
        tensor qualifies.
    """
    global DATA
    if not state_dict:
        return []

    columns = ["name", "layer", "group_name", "rank", "val"]
    low_precision = (torch.float16, torch.bfloat16)
    choices = set()
    rows = []
    for name, tensor in state_dict.items():
        match = number_re.search(name)
        if match is None or len(tensor.shape) != 2:
            continue
        # svdvals only accepts float/double/complex inputs: integer or bool
        # buffers are skipped, reduced-precision weights are upcast.
        if not (tensor.is_floating_point() or tensor.is_complex()):
            continue
        if tensor.dtype in low_precision:
            tensor = tensor.float()

        group_name = number_re.sub(".{N}.", name)
        choices.add(group_name)
        # Strip the surrounding dots of the first numeric segment.
        layer = int(match.group()[1:-1])

        svdvals = torch.linalg.svdvals(tensor)
        total = svdvals.sum()
        # An all-zero matrix would otherwise turn into NaNs (0 / 0).
        if total > 0:
            svdvals = svdvals / total
        rows.extend(
            (name, layer, group_name, rank, val)
            for rank, val in enumerate(svdvals.tolist())
        )

    # Build the frame straight from the row tuples: every column keeps its
    # native type, and an empty result still yields a five-column frame.
    DATA = pd.DataFrame(rows, columns=columns)
    DATA["val"] = DATA["val"].astype("float")
    DATA["layer"] = DATA["layer"].astype("category")
    DATA["rank"] = DATA["rank"].astype("int32")
    # A sorted list gives the dropdown a stable order and matches the type
    # returned on the empty-input path.
    return sorted(choices)
def weights_fn(model_id):
    """Load *model_id* and refresh the weight-group dropdown.

    Builds a ``transformers`` pipeline for *model_id* and stores its model's
    state dict in the module-level ``STATE_DICT``. Any failure is printed and
    replaced by an empty state dict. The group names computed by
    ``find_choices`` are returned as a ``gr.Dropdown.update``.
    """
    global STATE_DICT
    try:
        loaded = pipeline(model=model_id).model.state_dict()
    except Exception as err:
        # Best effort: a bad or partially typed model id must not break the UI.
        print(err)
        loaded = {}
    STATE_DICT = loaded
    return gr.Dropdown.update(choices=find_choices(STATE_DICT))
# UI: one row with two columns -- the model id box and weight-group dropdown
# in the first, the line plot in the second.
with gr.Blocks() as scatter_plot:
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(label="model_id")
            weights = gr.Dropdown(label="weights")
        with gr.Column():
            plot = gr.LinePlot(show_label=False).style(container=True)

    # Changing the model id recomputes the dropdown choices; changing the
    # selected group redraws the plot.
    model_id.change(fn=weights_fn, inputs=model_id, outputs=weights)
    weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)


if __name__ == "__main__":
    scatter_plot.launch()