Upload folder using huggingface_hub
- .gitattributes +9 -0
- README.md +5 -0
- requirements.txt +13 -0
- src/__init__.py +0 -0
- src/config.py +66 -0
- src/exp.py +276 -0
- whisper-alignment-results/base-to-large-probes.safetensors +3 -0
- whisper-alignment-results/base-to-medium-probes.safetensors +3 -0
- whisper-alignment-results/base-to-small-probes.safetensors +3 -0
- whisper-alignment-results/base-vs-large-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/base-vs-medium-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/base-vs-small-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/small-to-large-probes.safetensors +3 -0
- whisper-alignment-results/small-to-medium-probes.safetensors +3 -0
- whisper-alignment-results/small-vs-large-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/small-vs-medium-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/tiny-to-base-probes.safetensors +3 -0
- whisper-alignment-results/tiny-to-large-probes.safetensors +3 -0
- whisper-alignment-results/tiny-to-medium-probes.safetensors +3 -0
- whisper-alignment-results/tiny-to-small-probes.safetensors +3 -0
- whisper-alignment-results/tiny-vs-base-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/tiny-vs-large-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/tiny-vs-medium-temporal-linear-mse-log.png +3 -0
- whisper-alignment-results/tiny-vs-small-temporal-linear-mse-log.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/base-vs-large-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/base-vs-medium-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/base-vs-small-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/small-vs-large-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/small-vs-medium-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/tiny-vs-base-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/tiny-vs-large-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/tiny-vs-medium-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
+whisper-alignment-results/tiny-vs-small-temporal-linear-mse-log.png filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,5 @@
+# Layer Alignment Analysis for Whisper Encoders
+
+Tools for analyzing and comparing the internal representations of OpenAI Whisper encoder models, designed for research on model interpretability and transferability.
+
+All settings are adjustable in `src/config.py`.
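A note on running it (an editor's sketch, not part of the committed files): because `src/exp.py` reads its settings through a relative import (`from .config import *`), it must be launched as a module from the repository root, i.e. `python -m src.exp`, after installing `requirements.txt`. The equivalent launch from Python, for completeness:

import runpy

# Hypothetical launcher; assumes the repository root is the current
# working directory and the dependencies are installed.
# Equivalent to `python -m src.exp`; run_name="__main__" makes the
# script's `if __name__ == "__main__":` block execute.
runpy.run_module("src.exp", run_name="__main__")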
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+joblib
+matplotlib
+numpy
+torch
+tqdm
+transformers
+datasets
+librosa
+soundfile
+safetensors
+
+isort
+black
src/__init__.py
ADDED
File without changes
src/config.py
ADDED
@@ -0,0 +1,66 @@
+# Memory Configuration for Whisper Alignment Analysis
+
+# Batch processing settings
+BATCH_SIZE = 16  # Reduce if you get OOM errors, increase for faster processing
+TRAINING_STEPS = 200  # Number of training steps for linear probes
+LEARNING_RATE = 1e-3
+
+# Model selection
+ENABLED_MODELS = {
+    "tiny": "openai/whisper-tiny",  # ~39M parameters
+    "base": "openai/whisper-base",  # ~74M parameters
+    "small": "openai/whisper-small",  # ~244M parameters
+    "medium": "openai/whisper-medium",  # ~769M parameters
+    "large": "openai/whisper-large-v3-turbo",  # ~809M parameters
+}
+
+# Memory optimization settings
+USE_HALF_PRECISION = (
+    True  # Use half precision (bfloat16 preferred, float16 fallback) instead of float32
+)
+AGGRESSIVE_CLEANUP = False  # Clear GPU cache after each operation
+
+# Dataset settings
+MAX_SAMPLES = None  # Set to a number to limit dataset size for testing (e.g., 50)
+
+# Output settings
+OUTPUT_DIR = "whisper-alignment-results"
+SAVE_PLOTS = True
+PLOT_DPI = 300
+
+
+# Half precision dtype selection (bfloat16 preferred if available, fallback to float16)
+def get_half_precision_dtype():
+    """
+    Determine the best half precision dtype based on hardware support.
+    bfloat16 is preferred when available as it has better numerical stability.
+    """
+    import torch
+
+    if not USE_HALF_PRECISION:
+        return torch.float32
+
+    # Check if bfloat16 is supported
+    if torch.cuda.is_available():
+        # Check GPU support for bfloat16
+        device_capability = torch.cuda.get_device_capability()
+        # bfloat16 is supported on Ampere (8.x) and newer GPUs
+        if device_capability[0] >= 8:
+            return torch.bfloat16
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        # Apple Silicon supports bfloat16
+        return torch.bfloat16
+    elif (
+        hasattr(torch, "backends")
+        and hasattr(torch.backends, "cpu")
+        and hasattr(torch.backends.cpu, "supports_bfloat16")
+    ):
+        # Check CPU support for bfloat16 (newer PyTorch versions)
+        if torch.backends.cpu.supports_bfloat16:
+            return torch.bfloat16
+
+    # Fallback to float16
+    return torch.float16
+
+
+HALF_PRECISION_DTYPE = get_half_precision_dtype()
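The bfloat16 preference in `get_half_precision_dtype()` is about dynamic range: bfloat16 keeps float32's 8 exponent bits, so values that overflow float16 survive the cast. A small illustration of that claim (an editor's sketch, not part of the commit):

import torch

x = torch.tensor(131008.0)   # about 2x float16's max finite value (~65504)
print(x.to(torch.float16))   # inf -- overflows float16
print(x.to(torch.bfloat16))  # ~131072 -- bfloat16 preserves float32's range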
src/exp.py
ADDED
@@ -0,0 +1,276 @@
+import os
+from itertools import combinations
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+from datasets import Audio, load_dataset
+from safetensors.torch import save_file
+from tqdm import tqdm
+from transformers import AutoFeatureExtractor, WhisperModel
+
+from .config import *
+
+model_ids = ENABLED_MODELS
+
+# Load dataset
+dataset = load_dataset("JacobLinCool/cv161-en-zh-subset-200", split="train")
+if MAX_SAMPLES is not None:
+    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
+    print(f"Limited dataset to {len(dataset)} samples for testing")
+
+dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
+
+device = torch.device(
+    "cuda"
+    if torch.cuda.is_available()
+    else "mps" if torch.backends.mps.is_available() else "cpu"
+)
+print(f"Using device: {device}")
+
+
+def extract_layer_reps_generator(model_id, batch_size=4):
+    """
+    Use a generator to process samples in batches, avoiding loading all hidden states into memory at once.
+    Yields (sample_idx, layer_reps) pairs, where layer_reps is a list of all layer representations for the sample.
+    """
+    model = WhisperModel.from_pretrained(model_id).to(device)
+    feat_ext = AutoFeatureExtractor.from_pretrained(model_id)
+    model.eval()
+
+    for i in tqdm(
+        range(0, len(dataset), batch_size), desc=f"Processing {model_id} in batches"
+    ):
+        batch_end = min(i + batch_size, len(dataset))
+        batch_samples = dataset.select(range(i, batch_end))
+
+        # Process each sample in the batch
+        for j, sample in enumerate(batch_samples):
+            audio = sample["audio"]
+            samples = audio["array"]
+            sr = audio["sampling_rate"]
+
+            inputs = feat_ext(
+                samples, sampling_rate=sr, return_tensors="pt"
+            ).input_features.to(device)
+            with torch.no_grad():
+                outputs = model.encoder(
+                    inputs, return_dict=True, output_hidden_states=True
+                )
+
+            # Save the full sequence for each layer and immediately move to CPU; optionally use half precision to save memory
+            layer_reps_for_sample = []
+            for hs in outputs.hidden_states:
+                # hs: [1, T, D] -> [T, D]
+                layer_rep = hs.squeeze(0).cpu()
+                if USE_HALF_PRECISION:
+                    layer_rep = layer_rep.to(HALF_PRECISION_DTYPE)
+                layer_reps_for_sample.append(layer_rep)
+
+            yield i + j, layer_reps_for_sample
+
+            # Clean up GPU memory
+            del outputs, inputs
+            if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+    # Clean up model memory
+    del model, feat_ext
+    if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def compute_linear_mse_matrix_temporal_memory_efficient(
+    model_a_id, model_b_id, n_steps=200, lr=1e-3, batch_size=4
+):
+    """
+    Memory-efficient version: For each layer pair (i, j), trains a 1x1 convolution as a linear probe and computes MSE.
+    Uses a generator to process in batches, avoiding loading all representations into memory at once.
+    Returns an MSE matrix of shape (layers_a, layers_b) and all trained probes.
+    """
+    print(f"Computing alignment between {model_a_id} and {model_b_id}...")
+
+    # First, get the number of layers
+    sample_gen_a = extract_layer_reps_generator(model_a_id, batch_size=1)
+    _, sample_reps_a = next(sample_gen_a)
+    layers_a = len(sample_reps_a)
+
+    sample_gen_b = extract_layer_reps_generator(model_b_id, batch_size=1)
+    _, sample_reps_b = next(sample_gen_b)
+    layers_b = len(sample_reps_b)
+
+    mse_mat = np.zeros((layers_a, layers_b))
+    trained_probes = {}
+
+    pbar = tqdm(total=layers_a * layers_b, desc="Comparing layer pairs")
+
+    # Re-initialize generators to process all samples
+    gen_a = extract_layer_reps_generator(model_a_id, batch_size=batch_size)
+    gen_b = extract_layer_reps_generator(model_b_id, batch_size=batch_size)
+
+    # Collect all sample representations for specified layers
+    reps_a_dict_all = {}
+    for sample_idx, layer_reps in gen_a:
+        reps_a_dict_all[sample_idx] = layer_reps
+
+    reps_b_dict_all = {}
+    for sample_idx, layer_reps in gen_b:
+        reps_b_dict_all[sample_idx] = layer_reps
+
+    for i in range(layers_a):
+        for j in range(layers_b):
+            # Collect all sample representations for the specified layer
+            reps_a_dict = {}
+            for sample_idx, layer_reps in reps_a_dict_all.items():
+                if i < len(layer_reps):
+                    reps_a_dict[sample_idx] = layer_reps[i]
+
+            reps_b_dict = {}
+            for sample_idx, layer_reps in reps_b_dict_all.items():
+                if j < len(layer_reps):
+                    reps_b_dict[sample_idx] = layer_reps[j]
+
+            # Concatenate representations in order
+            X_list = [reps_a_dict[idx] for idx in sorted(reps_a_dict.keys())]
+            Y_list = [reps_b_dict[idx] for idx in sorted(reps_b_dict.keys())]
+
+            # Process in batches to avoid memory issues
+            X_cat = torch.cat(X_list, dim=0).to(device)
+            Y_cat = torch.cat(Y_list, dim=0).to(device)
+
+            dim_a = X_cat.shape[1]
+            dim_b = Y_cat.shape[1]
+
+            # For Conv1d, reshape to [Batch, Channels, Length]
+            X = X_cat.T.unsqueeze(0)  # [1, Dim_A, Total_Tokens]
+            Y = Y_cat.T.unsqueeze(0)  # [1, Dim_B, Total_Tokens]
+
+            # 2. Define and train linear probe (1x1 Conv)
+            probe = nn.Conv1d(
+                in_channels=dim_a, out_channels=dim_b, kernel_size=1, bias=False
+            ).to(device=device, dtype=HALF_PRECISION_DTYPE)
+            probe.train()
+
+            optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
+            loss_fn = nn.MSELoss()
+
+            for step in tqdm(range(n_steps), desc=f"Training probe {i}->{j}"):
+                optimizer.zero_grad()
+                Y_pred = probe(X)
+                loss = loss_fn(Y_pred, Y)
+                loss.backward()
+                optimizer.step()
+
+            # 3. Record final MSE and trained probe
+            final_mse = loss.item()
+            mse_mat[i, j] = final_mse
+            trained_probes[f"layer_{i}_to_{j}"] = probe.state_dict()["weight"]
+
+            # Clean up memory
+            del (
+                X_cat,
+                Y_cat,
+                X,
+                Y,
+                probe,
+                optimizer,
+                reps_a_dict,
+                reps_b_dict,
+                X_list,
+                Y_list,
+            )
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            pbar.update(1)
+            pbar.set_postfix({"layer_a": i, "layer_b": j, "mse": f"{final_mse:.4f}"})
+
+    pbar.close()
+    return mse_mat, trained_probes
+
+
+if __name__ == "__main__":
+    print("Memory optimization settings:")
+    print(f"  Batch size: {BATCH_SIZE}")
+    print(f"  Training steps: {TRAINING_STEPS}")
+    if USE_HALF_PRECISION:
+        dtype_name = "bfloat16" if HALF_PRECISION_DTYPE == torch.bfloat16 else "float16"
+        print(f"  Half precision: {USE_HALF_PRECISION} ({dtype_name})")
+    else:
+        print(f"  Half precision: {USE_HALF_PRECISION}")
+    print(f"  Aggressive cleanup: {AGGRESSIVE_CLEANUP}")
+    print(f"  Models: {list(model_ids.keys())}")
+    print(f"  Dataset size: {len(dataset)} samples")
+
+    # Create results directory
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+    # 2. Compare all model pairs - using memory-efficient method
+    model_names = list(model_ids.keys())
+    all_pairs = list(combinations(model_names, 2))
+
+    print(
+        f"\nProcessing {len(all_pairs)} model pairs with memory-efficient approach..."
+    )
+
+    for pair_idx, (model_a, model_b) in enumerate(all_pairs):
+        print(
+            f"\n[{pair_idx + 1}/{len(all_pairs)}] Computing temporal linear MSE for whisper-{model_a} vs whisper-{model_b}..."
+        )
+
+        # Compute linear MSE along the temporal dimension and get trained probes - memory-efficient version
+        mse_mat_temporal, trained_probes = (
+            compute_linear_mse_matrix_temporal_memory_efficient(
+                model_ids[model_a],
+                model_ids[model_b],
+                n_steps=TRAINING_STEPS,
+                lr=LEARNING_RATE,
+                batch_size=BATCH_SIZE,
+            )
+        )
+
+        # Save trained models
+        model_save_path = f"{OUTPUT_DIR}/{model_a}-to-{model_b}-probes.safetensors"
+        save_file(
+            trained_probes,
+            model_save_path,
+            {
+                "from_model": model_a,
+                "to_model": model_b,
+                "from_layers": str(len(mse_mat_temporal)),
+                "to_layers": str(len(mse_mat_temporal[0])),
+            },
+        )
+        print(f"Saved trained probes to: {model_save_path}")
+
+        if SAVE_PLOTS:
+            # Visualize results
+            # Avoid log(0) by adding a small value
+            eps = 1e-10
+            log_mse_mat = -np.log10(mse_mat_temporal + eps)
+
+            plt.figure(figsize=(8, 6))
+            plt.imshow(
+                log_mse_mat, aspect="auto", origin="lower"
+            )  # origin='lower' is more standard for matrices
+            plt.colorbar(label="-log10(MSE)")
+            plt.title(
+                f"Temporal Linear MSE (log scale): whisper-{model_a} vs whisper-{model_b}"
+            )
+            plt.xlabel(f"whisper-{model_b} layers")
+            plt.ylabel(f"whisper-{model_a} layers")
+            plt.tight_layout()
+
+            # Save visualization results
+            plot_save_path = (
+                f"{OUTPUT_DIR}/{model_a}-vs-{model_b}-temporal-linear-mse-log.png"
+            )
+            plt.savefig(plot_save_path, dpi=PLOT_DPI)
+            plt.close()  # Close figure to save memory
+            print(f"Saved plot to: {plot_save_path}")
+
+    print(f"\nAll experiments complete! Results saved to '{OUTPUT_DIR}' directory")
+    print(
+        f"Generated {len(all_pairs)} visualization plots and {len(all_pairs)} trained probe models"
+    )
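Once saved, the probes can be reloaded and applied outside the training loop. A minimal sketch, assuming only what `exp.py` above establishes (key naming `layer_{i}_to_{j}`, bias-free Conv1d weights of shape [dim_b, dim_a, 1]); the file name and layer indices are examples:

import torch
from safetensors.torch import load_file

probes = load_file("whisper-alignment-results/tiny-to-base-probes.safetensors")
W = probes["layer_0_to_0"]  # Conv1d weight: [dim_b, dim_a, 1], no bias

# A 1x1 convolution over time is a per-frame matrix multiply, so a
# tiny-encoder sequence [T, dim_a] maps into base space as [T, dim_b].
x = torch.randn(1500, W.shape[1], dtype=W.dtype)  # dummy features, T=1500
y = x @ W.squeeze(-1).T
print(y.shape)  # torch.Size([1500, dim_b])

Since the probe is purely linear, a low final MSE for a pair (i, j) means layer j of model B is approximately a linear image of layer i of model A; the heatmaps plot -log10(MSE), so an MSE of 1e-3 appears as 3.0, and higher values indicate tighter alignment.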
whisper-alignment-results/base-to-large-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:010e2238c4fd6584ed9e8b3cbd9a3379633c585a3d8f4dee2617679002193809
+size 302797200
whisper-alignment-results/base-to-medium-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f12bfaa489bcf81f3d0e51109e81e0ff2b03a31507c8742da5f4ea6e53b996b1
+size 183516544
whisper-alignment-results/base-to-small-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab25b9c2c210ebc1e0c3a035025322f949dc449b9f9e4a119dcb21e527545425
+size 71573320
whisper-alignment-results/base-vs-large-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)
whisper-alignment-results/base-vs-medium-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)
whisper-alignment-results/base-vs-small-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)
whisper-alignment-results/small-to-large-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e025459eeb94c6209e04d348e18eb3cb3912f128ba68adc544df9b3ce8c3ae3
+size 843487312
whisper-alignment-results/small-to-medium-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4225bd80384af8155badfbc9ace80114339e18c893c68ba787811aab776cf59a
+size 511210280
whisper-alignment-results/small-vs-large-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)
whisper-alignment-results/small-vs-medium-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)
whisper-alignment-results/tiny-to-base-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b10f010c37f5b884e9b2e5a05cbb87226d3a660c51c894e4ce175bd1093da7e
+size 13765648
whisper-alignment-results/tiny-to-large-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d363b8111d23d57e5c45a49e716a6c58c68c8f8c637175c4617ec0bed9f75cd9
+size 162216440
whisper-alignment-results/tiny-to-medium-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8879301313af0f749dee9587fd917bf6923477ede8984b1a4a8dbedb169d828f
+size 98315144
whisper-alignment-results/tiny-to-small-probes.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27d2ac798074676cf69de46554572c37279fc92ba55b459f892fee4eba686ca2
+size 38344296
whisper-alignment-results/tiny-vs-base-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)
whisper-alignment-results/tiny-vs-large-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)
whisper-alignment-results/tiny-vs-medium-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)
whisper-alignment-results/tiny-vs-small-temporal-linear-mse-log.png
ADDED
(binary image tracked with Git LFS)