import os
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from datasets import Audio, load_dataset
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoFeatureExtractor, WhisperModel

# Relative import: run this file as a module (python -m <package>.<module>) so the
# package context is available.
from .config import *
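
# Names assumed to be provided by .config (not defined in this file): ENABLED_MODELS is
# expected to map a short model name (e.g. "tiny") to a Hugging Face Whisper model id,
# alongside MAX_SAMPLES, USE_HALF_PRECISION, HALF_PRECISION_DTYPE, AGGRESSIVE_CLEANUP,
# BATCH_SIZE, TRAINING_STEPS, LEARNING_RATE, OUTPUT_DIR, SAVE_PLOTS, and PLOT_DPI.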
model_ids = ENABLED_MODELS


dataset = load_dataset("JacobLinCool/cv161-en-zh-subset-200", split="train")
if MAX_SAMPLES is not None:
    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
    print(f"Limited dataset to {len(dataset)} samples for testing")

dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using device: {device}")


def extract_layer_reps_generator(model_id, batch_size=4):
    """
    Process samples in batches with a generator so that hidden states never have to be
    held in memory for the whole dataset at once.
    Yields (sample_idx, layer_reps) pairs, where layer_reps is a list of the encoder's
    hidden states for that sample (one tensor per layer, including the embedding output).
    A usage sketch follows this function.
    """
    model = WhisperModel.from_pretrained(model_id).to(device)
    feat_ext = AutoFeatureExtractor.from_pretrained(model_id)
    model.eval()

    for i in tqdm(
        range(0, len(dataset), batch_size), desc=f"Processing {model_id} in batches"
    ):
        batch_end = min(i + batch_size, len(dataset))
        batch_samples = dataset.select(range(i, batch_end))

        for j, sample in enumerate(batch_samples):
            audio = sample["audio"]
            samples = audio["array"]
            sr = audio["sampling_rate"]

            inputs = feat_ext(
                samples, sampling_rate=sr, return_tensors="pt"
            ).input_features.to(device)
            with torch.no_grad():
                outputs = model.encoder(
                    inputs, return_dict=True, output_hidden_states=True
                )

            # Drop the batch dimension and optionally downcast to save memory.
            layer_reps_for_sample = []
            for hs in outputs.hidden_states:
                layer_rep = hs.squeeze(0)
                if USE_HALF_PRECISION:
                    layer_rep = layer_rep.to(HALF_PRECISION_DTYPE)
                layer_reps_for_sample.append(layer_rep)

            yield i + j, layer_reps_for_sample

            del outputs, inputs
            if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
                torch.cuda.empty_cache()

    del model, feat_ext
    if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
        torch.cuda.empty_cache()
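
# Usage sketch (illustrative only; the real call sites are in the function below, and the
# model id here is just an example rather than one taken from ENABLED_MODELS):
#
#   for sample_idx, layer_reps in extract_layer_reps_generator("openai/whisper-tiny"):
#       # layer_reps[k] has shape (seq_len, hidden_dim); index 0 is the embedding output.
#       print(sample_idx, len(layer_reps), layer_reps[0].shape)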


def compute_linear_mse_matrix_temporal_memory_efficient(
    model_a_id, model_b_id, n_steps=200, lr=1e-3, batch_size=4
):
    """
    For each layer pair (i, j), trains a 1x1 convolution as a linear probe mapping model A's
    layer-i representations to model B's layer-j representations and records the final MSE.
    Representations are extracted in batches (optionally in half precision) and cached once
    per model, so each model only runs its forward passes a single time.
    Returns an MSE matrix of shape (layers_a, layers_b) and all trained probe weights.
    A sketch of how to load and apply the saved probes follows this function.
    """
    print(f"Computing alignment between {model_a_id} and {model_b_id}...")

    # Run a single sample through each model just to count its encoder layers.
    sample_gen_a = extract_layer_reps_generator(model_a_id, batch_size=1)
    _, sample_reps_a = next(sample_gen_a)
    layers_a = len(sample_reps_a)
    sample_gen_a.close()  # release the temporarily loaded model

    sample_gen_b = extract_layer_reps_generator(model_b_id, batch_size=1)
    _, sample_reps_b = next(sample_gen_b)
    layers_b = len(sample_reps_b)
    sample_gen_b.close()

    mse_mat = np.zeros((layers_a, layers_b))
    trained_probes = {}

    pbar = tqdm(total=layers_a * layers_b, desc="Comparing layer pairs")

    gen_a = extract_layer_reps_generator(model_a_id, batch_size=batch_size)
    gen_b = extract_layer_reps_generator(model_b_id, batch_size=batch_size)

    # Cache every sample's layer representations once per model.
    reps_a_dict_all = {}
    for sample_idx, layer_reps in gen_a:
        reps_a_dict_all[sample_idx] = layer_reps

    reps_b_dict_all = {}
    for sample_idx, layer_reps in gen_b:
        reps_b_dict_all[sample_idx] = layer_reps

    for i in range(layers_a):
        for j in range(layers_b):
            # Collect layer i of model A and layer j of model B for every sample.
            reps_a_dict = {}
            for sample_idx, layer_reps in reps_a_dict_all.items():
                if i < len(layer_reps):
                    reps_a_dict[sample_idx] = layer_reps[i]

            reps_b_dict = {}
            for sample_idx, layer_reps in reps_b_dict_all.items():
                if j < len(layer_reps):
                    reps_b_dict[sample_idx] = layer_reps[j]

            X_list = [reps_a_dict[idx] for idx in sorted(reps_a_dict.keys())]
            Y_list = [reps_b_dict[idx] for idx in sorted(reps_b_dict.keys())]

            # Concatenate along time: (total_frames, dim_a) and (total_frames, dim_b).
            X_cat = torch.cat(X_list, dim=0).to(device)
            Y_cat = torch.cat(Y_list, dim=0).to(device)

            dim_a = X_cat.shape[1]
            dim_b = Y_cat.shape[1]

            # Conv1d expects (batch, channels, time).
            X = X_cat.T.unsqueeze(0)
            Y = Y_cat.T.unsqueeze(0)

            # A kernel-size-1 convolution without bias is a per-frame linear map.
            # Match the probe's dtype to the data so full-precision runs also work.
            probe = nn.Conv1d(
                in_channels=dim_a, out_channels=dim_b, kernel_size=1, bias=False
            ).to(device=device, dtype=X.dtype)
            probe.train()

            optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
            loss_fn = nn.MSELoss()

            for step in tqdm(range(n_steps), desc=f"Training probe {i}->{j}"):
                optimizer.zero_grad()
                Y_pred = probe(X)
                loss = loss_fn(Y_pred, Y)
                loss.backward()
                optimizer.step()

            # Training loss at the last step (no held-out split is used here).
            final_mse = loss.item()
            mse_mat[i, j] = final_mse
            # Keep the weight on CPU so GPU memory is freed and saving stays cheap.
            trained_probes[f"layer_{i}_to_{j}"] = probe.weight.detach().cpu()

            del (
                X_cat,
                Y_cat,
                X,
                Y,
                probe,
                optimizer,
                reps_a_dict,
                reps_b_dict,
                X_list,
                Y_list,
            )
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            pbar.update(1)
            pbar.set_postfix({"layer_a": i, "layer_b": j, "mse": f"{final_mse:.4f}"})

    pbar.close()
    return mse_mat, trained_probes
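

# Illustrative helper (a minimal sketch; not called anywhere in this script): shows how a
# probe saved below could be loaded with safetensors and applied. Each saved probe is a
# Conv1d weight of shape (dim_b, dim_a, 1), so applying it to a (seq_len, dim_a) layer
# representation is just a matrix multiplication over the feature dimension.
def apply_saved_probe(probe_path, key, reps_a):
    from safetensors.torch import load_file

    weight = load_file(probe_path)[key]  # e.g. key = "layer_0_to_0"
    return reps_a.to(weight.dtype) @ weight.squeeze(-1).T  # (seq_len, dim_b)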


if __name__ == "__main__":
    print("Memory optimization settings:")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Training steps: {TRAINING_STEPS}")
    if USE_HALF_PRECISION:
        dtype_name = "bfloat16" if HALF_PRECISION_DTYPE == torch.bfloat16 else "float16"
        print(f"  Half precision: {USE_HALF_PRECISION} ({dtype_name})")
    else:
        print(f"  Half precision: {USE_HALF_PRECISION}")
    print(f"  Aggressive cleanup: {AGGRESSIVE_CLEANUP}")
    print(f"  Models: {list(model_ids.keys())}")
    print(f"  Dataset size: {len(dataset)} samples")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Compare every unordered pair of enabled models.
    model_names = list(model_ids.keys())
    all_pairs = list(combinations(model_names, 2))

    print(
        f"\nProcessing {len(all_pairs)} model pairs with memory-efficient approach..."
    )

    for pair_idx, (model_a, model_b) in enumerate(all_pairs):
        print(
            f"\n[{pair_idx + 1}/{len(all_pairs)}] Computing temporal linear MSE "
            f"for whisper-{model_a} vs whisper-{model_b}..."
        )

        mse_mat_temporal, trained_probes = (
            compute_linear_mse_matrix_temporal_memory_efficient(
                model_ids[model_a],
                model_ids[model_b],
                n_steps=TRAINING_STEPS,
                lr=LEARNING_RATE,
                batch_size=BATCH_SIZE,
            )
        )

        # safetensors metadata values must be strings.
        model_save_path = f"{OUTPUT_DIR}/{model_a}-to-{model_b}-probes.safetensors"
        save_file(
            trained_probes,
            model_save_path,
            {
                "from_model": model_a,
                "to_model": model_b,
                "from_layers": str(len(mse_mat_temporal)),
                "to_layers": str(len(mse_mat_temporal[0])),
            },
        )
        print(f"Saved trained probes to: {model_save_path}")

        if SAVE_PLOTS:
            # Plot -log10(MSE) so better-aligned layer pairs show up as larger values.
            eps = 1e-10
            log_mse_mat = -np.log10(mse_mat_temporal + eps)

            plt.figure(figsize=(8, 6))
            plt.imshow(log_mse_mat, aspect="auto", origin="lower")
            plt.colorbar(label="-log10(MSE)")
            plt.title(
                f"Temporal Linear MSE (log scale): whisper-{model_a} vs whisper-{model_b}"
            )
            plt.xlabel(f"whisper-{model_b} layers")
            plt.ylabel(f"whisper-{model_a} layers")
            plt.tight_layout()

            plot_save_path = (
                f"{OUTPUT_DIR}/{model_a}-vs-{model_b}-temporal-linear-mse-log.png"
            )
            plt.savefig(plot_save_path, dpi=PLOT_DPI)
            plt.close()
            print(f"Saved plot to: {plot_save_path}")

    print(f"\nAll experiments complete! Results saved to '{OUTPUT_DIR}' directory")
    print(
        f"Generated {len(all_pairs)} visualization plots and {len(all_pairs)} trained probe models"
    )