Vis_Diff / app.py
Steelskull's picture
Update app.py
05c9c38 verified
raw
history blame
5.18 kB
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import gradio as gr
def calculate_weight_diff(base_weight, chat_weight):
return torch.abs(base_weight - chat_weight).mean().item()
def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
layer_diffs = []
layers = zip(base_model.model.layers, chat_model.model.layers)
if load_one_at_a_time:
for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
layer_diff = {
'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
}
layer_diffs.append(layer_diff)
base_layer, chat_layer = None, None
del base_layer, chat_layer
else:
for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
layer_diff = {
'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
}
layer_diffs.append(layer_diff)
return layer_diffs
def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name): # Added model names as parameters
num_layers = len(layer_diffs)
num_components = len(layer_diffs[0])
fig, axs = plt.subplots(1, num_components, figsize=(24, 8))
fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
sns.heatmap(component_diffs, annot=True, fmt=".9f", cmap="YlGnBu", ax=axs[i], cbar=False)
axs[i].set_title(component)
axs[i].set_xlabel("Difference")
axs[i].set_ylabel("Layer")
axs[i].set_xticks([])
axs[i].set_yticks(range(num_layers))
axs[i].set_yticklabels(range(num_layers))
axs[i].invert_yaxis()
plt.tight_layout()
return fig
def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, use_auth_token=hf_token)
chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, use_auth_token=hf_token)
layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
fig = visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name) # Pass model names to visualization
return fig
iface = gr.Interface(
fn=gradio_interface,
inputs=[
gr.Textbox(lines=2, placeholder="Enter base model name"),
gr.Textbox(lines=2, placeholder="Enter chat model name"),
gr.Textbox(lines=2, placeholder="Enter Hugging Face token"),
gr.Checkbox(label="Load one layer at a time")
],
outputs="image",
title="Model Weight Difference Visualizer"
)
iface.launch()