Spaces:
Runtime error
Runtime error
File size: 2,181 Bytes
06e7970 3bd7542 06e7970 3bd7542 06e7970 3bd7542 06e7970 3bd7542 06e7970 e5280de 06e7970 4c415fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
from transformers import pipeline
import numpy as np
import pandas as pd
import re
import torch
import altair as alt
# Lift Altair's max-rows guard: DATA holds one row per singular value of
# every analysed matrix, which can easily exceed Altair's default row limit.
alt.data_transformers.disable_max_rows()
# Matches one purely numeric dotted path segment, e.g. ".12." in
# "model.layers.12.mlp.weight". `+` (not `*`) so an empty segment ("..") is
# never mistaken for a layer index, which would make int("") fail when the
# digits between the dots are parsed.
number_re = re.compile(r"\.[0-9]+\.")
# Parameters (name -> tensor) of the most recently loaded model; set by weights_fn().
STATE_DICT = {}
# Long-format table of normalized singular values that drives the plot;
# starts empty and is populated by find_choices().
DATA = pd.DataFrame()
def scatter_plot_fn(group_name):
    """Return a LinePlot update showing the spectra of one weight group.

    Args:
        group_name: a group name produced by ``find_choices`` (numeric path
            segments replaced by ``.{N}.``), as selected in the dropdown.

    Returns:
        ``gr.LinePlot.update`` plotting ``val`` (normalized singular value)
        against ``rank``, one line per ``layer``.
    """
    if "group_name" in DATA.columns:
        df = DATA[DATA["group_name"] == group_name]
    else:
        # Nothing analysed yet: DATA is still the column-less placeholder, so
        # attribute access on it would raise. Plot an empty frame instead
        # (LinePlot.update requires a DataFrame whenever x/y are given).
        df = pd.DataFrame(columns=["name", "layer", "group_name", "rank", "val"])
    return gr.LinePlot.update(
        value=df,
        x="rank",
        y="val",
        color="layer",
        tooltip=["val", "rank", "layer"],
        caption="",
    )
def find_choices(state_dict):
    """Compute normalized singular-value spectra of layered 2-D weights.

    A tensor is analysed when its name contains a numeric dotted segment
    matched by ``number_re`` (e.g. ``.3.``) and it has exactly two
    dimensions. Tensors whose names differ only in those segments share a
    group name in which every such segment is replaced by ``.{N}.``.

    Side effect: replaces the module-level ``DATA`` with a long-format frame
    (columns ``name``, ``layer``, ``group_name``, ``rank``, ``val``) holding
    one row per singular value. ``val`` is the singular value divided by the
    sum of that tensor's singular values, ``rank`` is its position in the
    ``torch.linalg.svdvals`` output, and ``layer`` is the first numeric
    segment of the name. An empty input resets ``DATA`` to an empty frame
    with those columns, so no stale rows from a previous model remain.

    Args:
        state_dict: mapping of parameter name to tensor, such as the result
            of ``model.state_dict()``. May be empty or ``None``.

    Returns:
        Sorted list of distinct group names (empty when nothing qualifies).
    """
    global DATA
    rows = []
    choices = set()
    for name, tensor in (state_dict or {}).items():
        match = number_re.search(name)
        if match is None or tensor.ndim != 2:
            continue
        group_name = number_re.sub(".{N}.", name)
        choices.add(group_name)
        layer = int(match.group()[1:-1])
        # svdvals only accepts float/double (or complex) inputs; upcast
        # half/bfloat16/integer weights, keep float32/float64 as-is.
        if tensor.dtype not in (torch.float32, torch.float64):
            tensor = tensor.float()
        svdvals = torch.linalg.svdvals(tensor)
        total = svdvals.sum()
        if total > 0:  # an all-zero matrix would otherwise produce NaNs
            svdvals /= total
        rows.extend(
            (name, layer, group_name, rank, val)
            for rank, val in enumerate(svdvals.tolist())
        )
    # Build the frame straight from the tuples: going through np.array turned
    # every field into a string (layer categories sorted "1" < "10" < "2")
    # and produced a 1-D array that the 5-column constructor rejected when no
    # tensor qualified.
    data = pd.DataFrame(rows, columns=["name", "layer", "group_name", "rank", "val"])
    data["val"] = data["val"].astype("float")
    data["layer"] = data["layer"].astype("category")
    data["rank"] = data["rank"].astype("int32")
    DATA = data
    return sorted(choices)
def weights_fn(model_id):
    """Load a model by id and return the weight groups available for plotting.

    The model is loaded with ``transformers.pipeline(model=...)`` and its
    ``state_dict()`` is stored in the module-level ``STATE_DICT``. Any load
    failure is printed and leaves ``STATE_DICT`` empty, which empties the
    dropdown instead of crashing the UI.

    Args:
        model_id: model identifier entered in the textbox; surrounding
            whitespace is ignored and a blank id loads nothing.

    Returns:
        ``gr.Dropdown.update`` whose choices are the sorted group names
        computed by ``find_choices``.
    """
    global STATE_DICT
    model_id = (model_id or "").strip()
    # Drop the previous model's tensors before loading the next one so both
    # are not held in memory at once.
    STATE_DICT = {}
    if model_id:
        try:
            pipe = pipeline(model=model_id)
            STATE_DICT = pipe.model.state_dict()
        except Exception as e:  # UI boundary: report and fall back to an empty dropdown
            print(e)
    # sorted(): deterministic order and a plain (JSON-serializable) list even
    # if find_choices hands back a set.
    choices = sorted(find_choices(STATE_DICT))
    return gr.Dropdown.update(choices=choices)
# UI: model-id textbox and weight-group dropdown on the left, spectrum plot on the right.
with gr.Blocks() as scatter_plot:
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(
                label="model_id",
                placeholder="Type a model id and press Enter",
            )
            weights = gr.Dropdown(label="weights")
        with gr.Column():
            plot = gr.LinePlot(show_label=False).style(container=True)
    # Load only when the id is submitted (Enter). `change` fires on every
    # keystroke, which called pipeline() for each partial id and could
    # download whichever prefix happened to be a valid model.
    model_id.submit(weights_fn, inputs=model_id, outputs=weights)
    weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)

if __name__ == "__main__":
    scatter_plot.launch()
|