"""Gradio demo that plots normalized singular values of a model's weights.

The user types a model id; every 2-D tensor in the model's state dict whose
name contains a numeric path component (e.g. ``.3.``) is decomposed with
``torch.linalg.svdvals`` and the normalized spectrum is plotted per layer.
"""

import re

import altair as alt
import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import pipeline

alt.data_transformers.disable_max_rows()

# Matches a numeric path component such as ".12." in "encoder.layer.12.weight".
# At least one digit is required so that an empty component ("..") is not
# treated as a layer index (int("") would raise ValueError).
number_re = re.compile(r"\.[0-9]+\.")

# Column layout of the module-level DATA frame.
_COLUMNS = ["name", "layer", "group_name", "rank", "val"]

# Reduced-precision dtypes that are upcast before the decomposition;
# torch.linalg.svdvals is documented for float/double/complex inputs only.
_REDUCED_PRECISION = (torch.float16, torch.bfloat16)

# State dict of the most recently loaded model ({} when loading failed).
STATE_DICT = {}
# One row per (tensor, singular-value rank); rebuilt by find_choices().
DATA = pd.DataFrame(columns=_COLUMNS)


def scatter_plot_fn(group_name):
    """Return a LinePlot update showing the spectra of one weight group.

    Args:
        group_name: Tensor-name pattern (numeric components replaced by
            ``{N}``) selecting the rows of ``DATA`` to plot.

    Returns:
        The ``gr.LinePlot.update`` payload with rank on x, normalized
        singular value on y and one colored line per layer.
    """
    selected = DATA[DATA["group_name"] == group_name]
    return gr.LinePlot.update(
        value=selected,
        x="rank",
        y="val",
        color="layer",
        tooltip=["val", "rank", "layer"],
        caption="",
    )


def find_choices(state_dict):
    """Compute normalized singular values for all layered 2-D tensors.

    A tensor is "layered" when its name contains a numeric component matched
    by ``number_re`` and it has exactly two dimensions. The module-level
    ``DATA`` frame is replaced with one row per singular value; it is reset
    to an empty frame when nothing qualifies, so stale rows from a previously
    loaded model never remain.

    Args:
        state_dict: Mapping of tensor name to tensor; may be empty.

    Returns:
        Sorted list of group names (tensor names with every numeric
        component replaced by ``.{N}.``).
    """
    global DATA

    rows = []
    group_names = set()
    for name, tensor in (state_dict or {}).items():
        match = number_re.search(name)
        if match is None or len(tensor.shape) != 2:
            continue

        group_name = number_re.sub(".{N}.", name)
        group_names.add(group_name)
        # The first numeric component is taken as the layer index.
        layer = int(match.group()[1:-1])

        matrix = tensor
        if matrix.dtype in _REDUCED_PRECISION:
            matrix = matrix.float()
        svdvals = torch.linalg.svdvals(matrix)
        # Normalize so the singular values of each tensor sum to 1.
        svdvals = svdvals / svdvals.sum()

        rows.extend(
            (name, layer, group_name, rank, val)
            for rank, val in enumerate(svdvals.tolist())
        )

    # Build the frame straight from the row tuples: this keeps native dtypes
    # and also works when ``rows`` is empty.
    frame = pd.DataFrame(rows, columns=_COLUMNS)
    frame["val"] = frame["val"].astype("float")
    # Layers are stored as string categories so they plot as discrete series.
    frame["layer"] = frame["layer"].astype(str).astype("category")
    frame["rank"] = frame["rank"].astype("int32")
    DATA = frame

    # A sorted list gives the dropdown a stable, serializable set of options.
    return sorted(group_names)


def weights_fn(model_id):
    """Load ``model_id`` and return a Dropdown update listing its weight groups.

    Loading is best-effort: any failure is printed and treated as a model
    with no weights, which yields an empty list of choices.

    Args:
        model_id: Model identifier passed to ``transformers.pipeline``.

    Returns:
        The ``gr.Dropdown.update`` payload carrying the available groups.
    """
    global STATE_DICT
    try:
        pipe = pipeline(model=model_id)
        STATE_DICT = pipe.model.state_dict()
    except Exception as e:
        print(e)
        STATE_DICT = {}
    choices = find_choices(STATE_DICT)
    return gr.Dropdown.update(choices=choices)


with gr.Blocks() as scatter_plot:
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(label="model_id")
            weights = gr.Dropdown(label="weights")
        with gr.Column():
            plot = gr.LinePlot(show_label=False).style(container=True)
    # Typing a model id refreshes the dropdown; picking a group redraws the plot.
    model_id.change(weights_fn, inputs=model_id, outputs=weights)
    weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)

if __name__ == "__main__":
    scatter_plot.launch()