# Uploaded by JacobLinCool via huggingface_hub ("Upload folder using huggingface_hub").
# Source commit: 3b3134b (verified)
import os
from itertools import combinations
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from datasets import Audio, load_dataset
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoFeatureExtractor, WhisperModel
from .config import *
# Mapping of short model name -> Hugging Face model id, taken from the config module.
model_ids = ENABLED_MODELS

# Evaluation audio shared by every model comparison in this script.
dataset = load_dataset("JacobLinCool/cv161-en-zh-subset-200", split="train")
if MAX_SAMPLES is not None:
    # Optional cap on the number of rows, for quick test runs.
    n_keep = min(MAX_SAMPLES, len(dataset))
    dataset = dataset.select(range(n_keep))
    print(f"Limited dataset to {len(dataset)} samples for testing")
# Decode/resample the audio column to 16 kHz; the feature extractor later
# receives this sampling rate explicitly.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")
def extract_layer_reps_generator(model_id, batch_size=4):
    """
    Lazily encode every row of the module-level ``dataset`` with a Whisper encoder.

    Rows are selected from ``dataset`` in chunks of ``batch_size``. Each sample is
    still encoded on its own, so only one sample's activations live on ``device``
    at a time.

    Args:
        model_id: Model id/path loadable by ``WhisperModel`` and ``AutoFeatureExtractor``.
        batch_size: Number of dataset rows selected per chunk.

    Yields:
        ``(sample_idx, layer_reps)`` pairs. ``sample_idx`` is the row index in
        ``dataset``. ``layer_reps`` is a list with one CPU tensor per entry of the
        encoder's ``hidden_states``, with the leading batch dimension of size 1
        squeezed away. Tensors are cast to ``HALF_PRECISION_DTYPE`` when
        ``USE_HALF_PRECISION`` is set.

    The model and feature extractor are released in a ``finally`` clause. They
    are therefore freed even if the caller stops iterating early and the
    generator is closed or garbage-collected.
    """
    model = WhisperModel.from_pretrained(model_id).to(device)
    feat_ext = AutoFeatureExtractor.from_pretrained(model_id)
    model.eval()
    try:
        for start in tqdm(
            range(0, len(dataset), batch_size), desc=f"Processing {model_id} in batches"
        ):
            end = min(start + batch_size, len(dataset))
            batch_samples = dataset.select(range(start, end))
            for offset, sample in enumerate(batch_samples):
                audio = sample["audio"]
                inputs = feat_ext(
                    audio["array"],
                    sampling_rate=audio["sampling_rate"],
                    return_tensors="pt",
                ).input_features.to(device)
                with torch.no_grad():
                    outputs = model.encoder(
                        inputs, return_dict=True, output_hidden_states=True
                    )
                layer_reps_for_sample = []
                for hs in outputs.hidden_states:
                    # hs: [1, T, D] -> [T, D]
                    layer_rep = hs.squeeze(0)
                    if USE_HALF_PRECISION:
                        # Cast on-device first so the host transfer is smaller.
                        layer_rep = layer_rep.to(HALF_PRECISION_DTYPE)
                    # Move to host memory: callers may keep every sample's reps around.
                    layer_reps_for_sample.append(layer_rep.cpu())
                # Free device-side activations before handing control to the caller.
                del outputs, inputs
                if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
                    torch.cuda.empty_cache()
                yield start + offset, layer_reps_for_sample
    finally:
        del model, feat_ext
        if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
            torch.cuda.empty_cache()
def _collect_all_reps(model_id, batch_size):
    """Run ``extract_layer_reps_generator`` to exhaustion and return ``{sample_idx: layer_reps}``."""
    return dict(extract_layer_reps_generator(model_id, batch_size=batch_size))


def _stack_layer(reps_by_sample, layer):
    """Concatenate one layer's per-sample [T, D] reps along time, in sample-index order, on ``device``."""
    return torch.cat(
        [reps_by_sample[idx][layer] for idx in sorted(reps_by_sample)], dim=0
    ).to(device)


def _train_linear_probe(X, Y, n_steps, lr, desc):
    """Fit a bias-free 1x1 Conv1d from X [1, D_a, N] to Y [1, D_b, N]; return (probe, final_mse)."""
    # Match the probe dtype to the data. Hard-coding HALF_PRECISION_DTYPE crashed
    # with a dtype mismatch whenever USE_HALF_PRECISION was off (float32 inputs).
    probe = nn.Conv1d(
        in_channels=X.shape[1], out_channels=Y.shape[1], kernel_size=1, bias=False
    ).to(device=device, dtype=X.dtype)
    probe.train()
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for _ in tqdm(range(n_steps), desc=desc):
        optimizer.zero_grad()
        loss = loss_fn(probe(X), Y)
        loss.backward()
        optimizer.step()
    # Evaluate after the final update (the last training loss predates it), in
    # float32 so the reduction is not done in half precision. This also works
    # when n_steps == 0.
    probe.eval()
    with torch.no_grad():
        final_mse = loss_fn(probe(X).float(), Y.float()).item()
    return probe, final_mse


def compute_linear_mse_matrix_temporal_memory_efficient(
    model_a_id, model_b_id, n_steps=200, lr=1e-3, batch_size=4
):
    """
    Fit a linear probe from every encoder layer of model A to every layer of model B.

    For each layer pair (i, j), all samples' frames are concatenated into one long
    sequence. A bias-free 1x1 ``Conv1d`` (a per-frame linear map D_a -> D_b) is then
    trained with Adam to predict layer j of model B from layer i of model A.

    Memory: every layer's representations for both models are held in host (CPU)
    memory. Only the layer pair currently being fitted is moved to ``device``.

    Args:
        model_a_id: Source model id, passed to ``extract_layer_reps_generator``.
        model_b_id: Target model id, passed to ``extract_layer_reps_generator``.
        n_steps: Number of optimizer steps per probe.
        lr: Adam learning rate.
        batch_size: Row-chunk size used while extracting representations.

    Returns:
        ``(mse_mat, trained_probes)``.
        ``mse_mat`` is a float ``np.ndarray`` of shape ``(layers_a, layers_b)``
        holding each probe's MSE on the data it was trained on (no held-out split).
        ``trained_probes`` maps ``"layer_{i}_to_{j}"`` to the probe's CPU weight
        tensor of shape ``[D_b, D_a, 1]``.

    Raises:
        ValueError: If the dataset yields no samples, or if the two models produce
            a different total number of frames.
    """
    print(f"Computing alignment between {model_a_id} and {model_b_id}...")

    # Each model is loaded exactly once. Layer counts come from the collected reps
    # rather than from extra throw-away model loads.
    reps_a_all = _collect_all_reps(model_a_id, batch_size)
    reps_b_all = _collect_all_reps(model_b_id, batch_size)
    if not reps_a_all or not reps_b_all:
        raise ValueError("Dataset yielded no samples; nothing to align.")
    layers_a = len(next(iter(reps_a_all.values())))
    layers_b = len(next(iter(reps_b_all.values())))

    mse_mat = np.zeros((layers_a, layers_b))
    trained_probes = {}
    pbar = tqdm(total=layers_a * layers_b, desc="Comparing layer pairs")

    for i in range(layers_a):
        # Conv1d expects [batch, channels, length]: [N, D_a] -> [1, D_a, N].
        # Built once per source layer instead of once per (i, j) pair.
        X = _stack_layer(reps_a_all, i).T.unsqueeze(0)
        for j in range(layers_b):
            Y = _stack_layer(reps_b_all, j).T.unsqueeze(0)  # [1, D_b, N]
            if X.shape[2] != Y.shape[2]:
                raise ValueError(
                    f"Frame count mismatch between layer {i} of {model_a_id} "
                    f"({X.shape[2]}) and layer {j} of {model_b_id} ({Y.shape[2]})."
                )

            probe, final_mse = _train_linear_probe(
                X, Y, n_steps, lr, desc=f"Training probe {i}->{j}"
            )
            mse_mat[i, j] = final_mse
            # Keep a CPU copy so probe weights don't pile up on the accelerator.
            trained_probes[f"layer_{i}_to_{j}"] = probe.weight.detach().cpu()

            del Y, probe
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            pbar.update(1)
            pbar.set_postfix({"layer_a": i, "layer_b": j, "mse": f"{final_mse:.4f}"})
        del X

    pbar.close()
    return mse_mat, trained_probes
if __name__ == "__main__":
    # Echo the run configuration (all names come from the star-imported config).
    print("Memory optimization settings:")
    print(f" Batch size: {BATCH_SIZE}")
    print(f" Training steps: {TRAINING_STEPS}")
    if USE_HALF_PRECISION:
        # e.g. torch.bfloat16 -> "bfloat16"; torch.float16 -> "float16".
        dtype_name = str(HALF_PRECISION_DTYPE).replace("torch.", "")
        print(f" Half precision: {USE_HALF_PRECISION} ({dtype_name})")
    else:
        print(f" Half precision: {USE_HALF_PRECISION}")
    print(f" Aggressive cleanup: {AGGRESSIVE_CLEANUP}")
    print(f" Models: {list(model_ids.keys())}")
    print(f" Dataset size: {len(dataset)} samples")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Compare every unordered pair of enabled models (A vs B only, never B vs A).
    model_names = list(model_ids.keys())
    all_pairs = list(combinations(model_names, 2))
    print(
        f"\nProcessing {len(all_pairs)} model pairs with memory-efficient approach..."
    )

    num_plots = 0
    for pair_idx, (model_a, model_b) in enumerate(all_pairs):
        print(
            f"\n[{pair_idx + 1}/{len(all_pairs)}] Computing temporal linear MSE for whisper-{model_a} vs whisper-{model_b}..."
        )
        mse_mat_temporal, trained_probes = (
            compute_linear_mse_matrix_temporal_memory_efficient(
                model_ids[model_a],
                model_ids[model_b],
                n_steps=TRAINING_STEPS,
                lr=LEARNING_RATE,
                batch_size=BATCH_SIZE,
            )
        )
        n_layers_a, n_layers_b = mse_mat_temporal.shape

        # safetensors metadata must be a str -> str mapping.
        model_save_path = f"{OUTPUT_DIR}/{model_a}-to-{model_b}-probes.safetensors"
        save_file(
            trained_probes,
            model_save_path,
            {
                "from_model": str(model_a),
                "to_model": str(model_b),
                "from_layers": str(n_layers_a),
                "to_layers": str(n_layers_b),
            },
        )
        print(f"Saved trained probes to: {model_save_path}")

        if SAVE_PLOTS:
            # eps keeps log10 finite if an MSE is exactly 0.
            eps = 1e-10
            log_mse_mat = -np.log10(mse_mat_temporal + eps)
            plt.figure(figsize=(8, 6))
            # origin="lower" puts layer 0 of both models at the bottom-left corner.
            plt.imshow(log_mse_mat, aspect="auto", origin="lower")
            plt.colorbar(label="-log10(MSE)")
            plt.title(
                f"Temporal Linear MSE (log scale): whisper-{model_a} vs whisper-{model_b}"
            )
            plt.xlabel(f"whisper-{model_b} layers")
            plt.ylabel(f"whisper-{model_a} layers")
            plt.tight_layout()
            plot_save_path = (
                f"{OUTPUT_DIR}/{model_a}-vs-{model_b}-temporal-linear-mse-log.png"
            )
            plt.savefig(plot_save_path, dpi=PLOT_DPI)
            plt.close()  # release the figure so memory doesn't grow across pairs
            print(f"Saved plot to: {plot_save_path}")
            num_plots += 1

    print(f"\nAll experiments complete! Results saved to '{OUTPUT_DIR}' directory")
    # Report the plots actually written (zero when SAVE_PLOTS is off).
    print(
        f"Generated {num_plots} visualization plots and {len(all_pairs)} trained probe models"
    )